diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs index 2f03c008..f004a782 100644 --- a/LLama/LLamaModel.cs +++ b/LLama/LLamaModel.cs @@ -294,14 +294,10 @@ namespace LLama } } - var candidates = new List(); - candidates.Capacity = n_vocab; + var candidates = new LLamaTokenData[n_vocab]; for (llama_token token_id = 0; token_id < n_vocab; token_id++) - { - candidates.Add(new LLamaTokenData(token_id, logits[token_id], 0.0f)); - } - - LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates.ToArray(), (ulong)candidates.Count, false); + candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); + LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates); // Apply penalties float nl_logit = logits[NativeApi.llama_token_nl()]; diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index 65e09564..6e2c4a46 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -1,32 +1,80 @@ using System; using System.Buffers; -using System.Collections.Generic; using System.Runtime.InteropServices; -using System.Text; namespace LLama.Native { - [StructLayout(LayoutKind.Sequential)] + /// + /// Contains an array of LLamaTokenData, potentially sorted. + /// public struct LLamaTokenDataArray { - public Memory data; - public ulong size; - [MarshalAs(UnmanagedType.I1)] - public bool sorted; + /// + /// The LLamaTokenData + /// + public readonly Memory data; - public LLamaTokenDataArray(LLamaTokenData[] data, ulong size, bool sorted) + /// + /// Indicates if `data` is sorted + /// + public readonly bool sorted; + + /// + /// Create a new LLamaTokenDataArray + /// + /// + /// + public LLamaTokenDataArray(Memory tokens, bool isSorted = false) { - this.data = data; - this.size = size; - this.sorted = sorted; + data = tokens; + sorted = isSorted; } } + /// + /// Contains a pointer to an array of LLamaTokenData which is pinned in memory. + /// [StructLayout(LayoutKind.Sequential)] public struct LLamaTokenDataArrayNative { + /// + /// A pointer to an array of LlamaTokenData + /// + /// Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use public IntPtr data; + + /// + /// Number of LLamaTokenData in the array + /// public ulong size; + + /// + /// Indicates if the items in the array are sorted + /// + [MarshalAs(UnmanagedType.I1)] public bool sorted; + + /// + /// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray + /// + /// Data source + /// Created native array + /// A memory handle, pinning the data in place until disposed + public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataArrayNative native) + { + var handle = array.data.Pin(); + + unsafe + { + native = new LLamaTokenDataArrayNative + { + data = new IntPtr(handle.Pointer), + size = (ulong)array.data.Length, + sorted = array.sorted + }; + } + + return handle; + } } } diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs index 8c609896..45a9caf0 100644 --- a/LLama/Native/NativeApi.Sampling.cs +++ b/LLama/Native/NativeApi.Sampling.cs @@ -1,11 +1,10 @@ using System; -using System.Collections.Generic; using System.Runtime.InteropServices; -using System.Text; namespace LLama.Native { using llama_token = Int32; + public unsafe partial class NativeApi { /// @@ -17,7 +16,7 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, IntPtr candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty); + public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty); /// /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. @@ -29,7 +28,7 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, IntPtr candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); + public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); /// /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. @@ -37,7 +36,7 @@ namespace LLama.Native /// /// Pointer to LLamaTokenDataArray [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_softmax(SafeLLamaContextHandle ctx, IntPtr candidates); + public static extern void llama_sample_softmax(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); /// /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 @@ -47,7 +46,7 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_top_k(SafeLLamaContextHandle ctx, IntPtr candidates, int k, ulong min_keep); + public static extern void llama_sample_top_k(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, int k, ulong min_keep); /// /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 @@ -57,7 +56,7 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_top_p(SafeLLamaContextHandle ctx, IntPtr candidates, float p, ulong min_keep); + public static extern void llama_sample_top_p(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep); /// /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. @@ -67,7 +66,7 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_tail_free(SafeLLamaContextHandle ctx, IntPtr candidates, float z, ulong min_keep); + public static extern void llama_sample_tail_free(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float z, ulong min_keep); /// /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. @@ -77,10 +76,16 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_typical(SafeLLamaContextHandle ctx, IntPtr candidates, float p, ulong min_keep); + public static extern void llama_sample_typical(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep); + /// + /// Modify logits by temperature + /// + /// + /// + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_temperature(SafeLLamaContextHandle ctx, IntPtr candidates, float temp); + public static extern void llama_sample_temperature(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float temp); /// /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. @@ -93,7 +98,7 @@ namespace LLama.Native /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, int m, float* mu); + public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, float* mu); /// /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. @@ -105,7 +110,7 @@ namespace LLama.Native /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, IntPtr candidates, float tau, float eta, float* mu); + public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, float* mu); /// /// Selects the token with the highest probability. @@ -114,7 +119,7 @@ namespace LLama.Native /// Pointer to LLamaTokenDataArray /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, IntPtr candidates); + public static extern llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); /// /// Randomly selects a token from the candidates based on their probabilities. @@ -123,6 +128,6 @@ namespace LLama.Native /// Pointer to LLamaTokenDataArray /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_sample_token(SafeLLamaContextHandle ctx, IntPtr candidates); + public static extern llama_token llama_sample_token(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); } } diff --git a/LLama/Native/SamplingApi.cs b/LLama/Native/SamplingApi.cs index 9a9021ed..f84ac1b1 100644 --- a/LLama/Native/SamplingApi.cs +++ b/LLama/Native/SamplingApi.cs @@ -1,7 +1,4 @@ using System; -using System.Collections.Generic; -using System.Runtime.InteropServices; -using System.Text; namespace LLama.Native { @@ -18,12 +15,8 @@ namespace LLama.Native /// public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float penalty) { - var handle = candidates.data.Pin(); - var st = new LLamaTokenDataArrayNative(); - st.data = new IntPtr(handle.Pointer); - st.size = candidates.size; - st.sorted = candidates.sorted; - NativeApi.llama_sample_repetition_penalty(ctx, new IntPtr(&st), last_tokens, last_tokens_size, penalty); + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); + NativeApi.llama_sample_repetition_penalty(ctx, ref st, last_tokens, last_tokens_size, penalty); } /// @@ -37,12 +30,8 @@ namespace LLama.Native /// public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, llama_token[] last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) { - var handle = candidates.data.Pin(); - var st = new LLamaTokenDataArrayNative(); - st.data = new IntPtr(handle.Pointer); - st.size = candidates.size; - st.sorted = candidates.sorted; - NativeApi.llama_sample_frequency_and_presence_penalties(ctx, new IntPtr(&st), last_tokens, last_tokens_size, alpha_frequency, alpha_presence); + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); + NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, last_tokens, last_tokens_size, alpha_frequency, alpha_presence); } /// @@ -52,12 +41,8 @@ namespace LLama.Native /// Pointer to LLamaTokenDataArray public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { - var handle = candidates.data.Pin(); - var st = new LLamaTokenDataArrayNative(); - st.data = new IntPtr(handle.Pointer); - st.size = candidates.size; - st.sorted = candidates.sorted; - NativeApi.llama_sample_softmax(ctx, new IntPtr(&st)); + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); + NativeApi.llama_sample_softmax(ctx, ref st); } /// @@ -69,12 +54,8 @@ namespace LLama.Native /// public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep) { - var handle = candidates.data.Pin(); - var st = new LLamaTokenDataArrayNative(); - st.data = new IntPtr(handle.Pointer); - st.size = candidates.size; - st.sorted = candidates.sorted; - NativeApi.llama_sample_top_k(ctx, new IntPtr(&st), k, min_keep); + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); + NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep); } /// @@ -86,12 +67,8 @@ namespace LLama.Native /// public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) { - var handle = candidates.data.Pin(); - var st = new LLamaTokenDataArrayNative(); - st.data = new IntPtr(handle.Pointer); - st.size = candidates.size; - st.sorted = candidates.sorted; - NativeApi.llama_sample_top_p(ctx, new IntPtr(&st), p, min_keep); + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); + NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep); } /// @@ -103,12 +80,8 @@ namespace LLama.Native /// public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep) { - var handle = candidates.data.Pin(); - var st = new LLamaTokenDataArrayNative(); - st.data = new IntPtr(handle.Pointer); - st.size = candidates.size; - st.sorted = candidates.sorted; - NativeApi.llama_sample_tail_free(ctx, new IntPtr(&st), z, min_keep); + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); + NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep); } /// @@ -120,22 +93,14 @@ namespace LLama.Native /// public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) { - var handle = candidates.data.Pin(); - var st = new LLamaTokenDataArrayNative(); - st.data = new IntPtr(handle.Pointer); - st.size = candidates.size; - st.sorted = candidates.sorted; - NativeApi.llama_sample_typical(ctx, new IntPtr(&st), p, min_keep); + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); + NativeApi.llama_sample_typical(ctx, ref st, p, min_keep); } public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp) { - var handle = candidates.data.Pin(); - var st = new LLamaTokenDataArrayNative(); - st.data = new IntPtr(handle.Pointer); - st.size = candidates.size; - st.sorted = candidates.sorted; - NativeApi.llama_sample_temperature(ctx, new IntPtr(&st), temp); + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); + NativeApi.llama_sample_temperature(ctx, ref st, temp); } /// @@ -150,17 +115,11 @@ namespace LLama.Native /// public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) { - var handle = candidates.data.Pin(); - var st = new LLamaTokenDataArrayNative(); - st.data = new IntPtr(handle.Pointer); - st.size = candidates.size; - st.sorted = candidates.sorted; - llama_token res; + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); fixed(float* pmu = &mu) { - res = NativeApi.llama_sample_token_mirostat(ctx, new IntPtr(&st), tau, eta, m, pmu); + return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, pmu); } - return res; } /// @@ -174,17 +133,11 @@ namespace LLama.Native /// public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) { - var handle = candidates.data.Pin(); - var st = new LLamaTokenDataArrayNative(); - st.data = new IntPtr(handle.Pointer); - st.size = candidates.size; - st.sorted = candidates.sorted; - llama_token res; + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); fixed (float* pmu = &mu) { - res = NativeApi.llama_sample_token_mirostat_v2(ctx, new IntPtr(&st), tau, eta, pmu); + return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, pmu); } - return res; } /// @@ -195,12 +148,8 @@ namespace LLama.Native /// public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { - var handle = candidates.data.Pin(); - var st = new LLamaTokenDataArrayNative(); - st.data = new IntPtr(handle.Pointer); - st.size = candidates.size; - st.sorted = candidates.sorted; - return NativeApi.llama_sample_token_greedy(ctx, new IntPtr(&st)); + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); + return NativeApi.llama_sample_token_greedy(ctx, ref st); } /// @@ -211,12 +160,8 @@ namespace LLama.Native /// public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { - var handle = candidates.data.Pin(); - var st = new LLamaTokenDataArrayNative(); - st.data = new IntPtr(handle.Pointer); - st.size = candidates.size; - st.sorted = candidates.sorted; - return NativeApi.llama_sample_token(ctx, new IntPtr(&st)); + using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); + return NativeApi.llama_sample_token(ctx, ref st); } } } diff --git a/LLama/OldVersion/LLamaModel.cs b/LLama/OldVersion/LLamaModel.cs index 46fc7e63..bf400ba4 100644 --- a/LLama/OldVersion/LLamaModel.cs +++ b/LLama/OldVersion/LLamaModel.cs @@ -632,14 +632,10 @@ namespace LLama.OldVersion logits[key] += value; } - var candidates = new List(); - candidates.Capacity = n_vocab; + var candidates = new LLamaTokenData[n_vocab]; for (llama_token token_id = 0; token_id < n_vocab; token_id++) - { - candidates.Add(new LLamaTokenData(token_id, logits[token_id], 0.0f)); - } - - LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates.ToArray(), (ulong)candidates.Count, false); + candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); + LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates); // Apply penalties float nl_logit = logits[NativeApi.llama_token_nl()];