diff --git a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs index 6402e360..46c0505e 100644 --- a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs +++ b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs @@ -1,9 +1,5 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,15 +8,22 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); - ChatSession session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); + + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + var session = new ChatSession(executor).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The chat session has started. The role names won't be printed."); Console.ForegroundColor = ConsoleColor.White; + // show the prompt + Console.Write(prompt); while (true) { foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) diff --git a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs index d1cbf34b..5e155252 100644 --- a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs +++ b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs @@ -1,9 +1,5 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,10 +8,15 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); - ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream. + + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InteractiveExecutor(context); + + var session = new ChatSession(executor); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result."); diff --git a/LLama.Examples/NewVersion/GetEmbeddings.cs b/LLama.Examples/NewVersion/GetEmbeddings.cs index ed12f868..516d2da7 100644 --- a/LLama.Examples/NewVersion/GetEmbeddings.cs +++ b/LLama.Examples/NewVersion/GetEmbeddings.cs @@ -1,9 +1,4 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,7 +7,7 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var embedder = new LLamaEmbedder(new ModelParams(modelPath)); while (true) diff --git a/LLama.Examples/NewVersion/InstructModeExecute.cs b/LLama.Examples/NewVersion/InstructModeExecute.cs index f81f2f58..d5b31416 100644 --- a/LLama.Examples/NewVersion/InstructModeExecute.cs +++ b/LLama.Examples/NewVersion/InstructModeExecute.cs @@ -1,9 +1,5 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,10 +8,13 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/dan.txt").Trim(); - InstructExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024))); + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var executor = new InstructExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions. For example, you can input \"Write a story about a fox who want to " + @@ -26,7 +25,7 @@ namespace LLama.Examples.NewVersion while (true) { - foreach (var text in ex.Infer(prompt, inferenceParams)) + foreach (var text in executor.Infer(prompt, inferenceParams)) { Console.Write(text); } diff --git a/LLama.Examples/NewVersion/InteractiveModeExecute.cs b/LLama.Examples/NewVersion/InteractiveModeExecute.cs index aaacabbe..cc7c2891 100644 --- a/LLama.Examples/NewVersion/InteractiveModeExecute.cs +++ b/LLama.Examples/NewVersion/InteractiveModeExecute.cs @@ -1,21 +1,20 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { public class InteractiveModeExecute { - public async static Task Run() + public static async Task Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); - var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); + var modelPath = Console.ReadLine(); + var prompt = (await File.ReadAllTextAsync("Assets/chat-with-bob.txt")).Trim(); - InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 256))); + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var ex = new InteractiveExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 128 and the context size is 256. (an example for small scale usage)"); diff --git a/LLama.Examples/NewVersion/LoadAndSaveSession.cs b/LLama.Examples/NewVersion/LoadAndSaveSession.cs index cbed9179..9ca07de0 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveSession.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveSession.cs @@ -1,10 +1,5 @@ using LLama.Common; -using LLama.OldVersion; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -13,10 +8,15 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); - ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream. + + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var ex = new InteractiveExecutor(context); + + var session = new ChatSession(ex); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result. Input \"save\" to save and reload the session."); diff --git a/LLama.Examples/NewVersion/LoadAndSaveState.cs b/LLama.Examples/NewVersion/LoadAndSaveState.cs index 15f2f815..d72fade8 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveState.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveState.cs @@ -1,9 +1,5 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,10 +8,13 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 256))); + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters); + var ex = new InteractiveExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 64 and the context size is 256. (an example for small scale usage)"); @@ -47,9 +46,9 @@ namespace LLama.Examples.NewVersion Console.WriteLine("All states saved!"); Console.ForegroundColor = ConsoleColor.White; - var model = ex.Context; - model.LoadState(modelStatePath); - ex = new InteractiveExecutor(model); + var ctx = ex.Context; + ctx.LoadState(modelStatePath); + ex = new InteractiveExecutor(ctx); ex.LoadState(executorStatePath); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Loaded state!"); diff --git a/LLama.Examples/NewVersion/QuantizeModel.cs b/LLama.Examples/NewVersion/QuantizeModel.cs index a5ad81d8..71966af8 100644 --- a/LLama.Examples/NewVersion/QuantizeModel.cs +++ b/LLama.Examples/NewVersion/QuantizeModel.cs @@ -1,11 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading; -using System.Threading.Tasks; - -namespace LLama.Examples.NewVersion +namespace LLama.Examples.NewVersion { public class QuantizeModel { @@ -13,13 +6,16 @@ namespace LLama.Examples.NewVersion { Console.Write("Please input your original model path: "); var inputPath = Console.ReadLine(); + Console.Write("Please input your output model path: "); var outputPath = Console.ReadLine(); + Console.Write("Please input the quantize type (one of q4_0, q4_1, q5_0, q5_1, q8_0): "); var quantizeType = Console.ReadLine(); + if (LLamaQuantizer.Quantize(inputPath, outputPath, quantizeType)) { - Console.WriteLine("Quantization succeed!"); + Console.WriteLine("Quantization succeeded!"); } else { diff --git a/LLama.Examples/NewVersion/StatelessModeExecute.cs b/LLama.Examples/NewVersion/StatelessModeExecute.cs index 8ff2c0a1..7f59e73e 100644 --- a/LLama.Examples/NewVersion/StatelessModeExecute.cs +++ b/LLama.Examples/NewVersion/StatelessModeExecute.cs @@ -1,9 +1,4 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,9 +7,11 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); - StatelessExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 256))); + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + var ex = new StatelessExecutor(model, parameters); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the inference is an one-time job. That says, the previous input and response has " + @@ -29,10 +26,10 @@ namespace LLama.Examples.NewVersion { Console.Write("\nQuestion: "); Console.ForegroundColor = ConsoleColor.Green; - string prompt = Console.ReadLine(); - Console.ForegroundColor = ConsoleColor.White; + var prompt = Console.ReadLine(); + Console.ForegroundColor = ConsoleColor.White; Console.Write("Answer: "); - prompt = $"Question: {prompt.Trim()} Answer: "; + prompt = $"Question: {prompt?.Trim()} Answer: "; foreach (var text in ex.Infer(prompt, inferenceParams)) { Console.Write(text); diff --git a/LLama.Examples/NewVersion/TalkToYourself.cs b/LLama.Examples/NewVersion/TalkToYourself.cs index 35a65241..309d5654 100644 --- a/LLama.Examples/NewVersion/TalkToYourself.cs +++ b/LLama.Examples/NewVersion/TalkToYourself.cs @@ -20,9 +20,9 @@ namespace LLama.Examples.NewVersion using var weights = LLamaWeights.LoadFromFile(@params); // Create 2 contexts sharing the same weights - using var aliceCtx = weights.CreateContext(@params, Encoding.UTF8); + using var aliceCtx = weights.CreateContext(@params); var alice = new InteractiveExecutor(aliceCtx); - using var bobCtx = weights.CreateContext(@params, Encoding.UTF8); + using var bobCtx = weights.CreateContext(@params); var bob = new InteractiveExecutor(bobCtx); // Initial alice prompt diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs index c90bc78d..6cc3f3da 100644 --- a/LLama.Examples/NewVersion/TestRunner.cs +++ b/LLama.Examples/NewVersion/TestRunner.cs @@ -1,10 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace LLama.Examples.NewVersion +namespace LLama.Examples.NewVersion { public class NewVersionTestRunner { diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs index a34f58cb..1c5fa952 100644 --- a/LLama.Unittest/LLamaContextTests.cs +++ b/LLama.Unittest/LLamaContextTests.cs @@ -16,7 +16,7 @@ namespace LLama.Unittest ContextSize = 768, }; _weights = LLamaWeights.LoadFromFile(@params); - _context = _weights.CreateContext(@params, Encoding.UTF8); + _context = _weights.CreateContext(@params); } public void Dispose() diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs new file mode 100644 index 00000000..37031da3 --- /dev/null +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -0,0 +1,70 @@ +using LLama.Common; +using Xunit.Abstractions; + +namespace LLama.Unittest +{ + public class StatelessExecutorTest + : IDisposable + { + private readonly ITestOutputHelper _testOutputHelper; + private readonly LLamaWeights _weights; + private readonly ModelParams _params; + + public StatelessExecutorTest(ITestOutputHelper testOutputHelper) + { + _testOutputHelper = testOutputHelper; + _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin") + { + ContextSize = 60, + Seed = 1754 + }; + _weights = LLamaWeights.LoadFromFile(_params); + } + + public void Dispose() + { + _weights.Dispose(); + } + + [Fact] + public void Stateless() + { + var executor = new StatelessExecutor(_weights, _params); + + const string question = "Question. what is a cat?\nAnswer: "; + var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } }; + + var result1 = string.Join("", executor.Infer(question, @params)); + var result2 = string.Join("", executor.Infer(question, @params)); + + _testOutputHelper.WriteLine(result1); + + // Check that it produced the exact same result both times + Assert.Equal(result1, result2); + } + + [Fact] + public void OutOfContext() + { + var executor = new StatelessExecutor(_weights, _params); + + const string question = " Question. why is a cat the best pet?\nAnswer: "; + + // The context size is set to 60. Generate more than that, forcing it to generate a coherent response + // with a modified context + var @params = new InferenceParams() + { + MaxTokens = 100, + TokensKeep = question.Length, + }; + + var result1 = string.Join("", executor.Infer(question, @params)); + var result2 = string.Join("", executor.Infer(question, @params)); + + _testOutputHelper.WriteLine(result1); + + // Check that it produced the exact same result both times + Assert.Equal(result1, result2); + } + } +} \ No newline at end of file diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index e8b89dee..3f5a3f0c 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -2,7 +2,8 @@ namespace LLama.Web.Common { - public class ModelOptions : IModelParams + public class ModelOptions + : IModelParams { public string Name { get; set; } @@ -111,5 +112,9 @@ namespace LLama.Web.Common /// public bool MulMatQ { get; set; } - } + /// + /// The encoding to use for models + /// + public string Encoding { get; set; } = "UTF-8"; + } } diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs index fdc91152..64a0125b 100644 --- a/LLama/Abstractions/IModelParams.cs +++ b/LLama/Abstractions/IModelParams.cs @@ -1,6 +1,4 @@ -using System; - -namespace LLama.Abstractions +namespace LLama.Abstractions { public interface IModelParams { @@ -119,5 +117,10 @@ namespace LLama.Abstractions /// Use experimental mul_mat_q kernels /// bool MulMatQ { get; set; } + + /// + /// The encoding to use for models + /// + string Encoding { get; set; } } } \ No newline at end of file diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 5cb81078..c0741abe 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -1,5 +1,6 @@ using LLama.Abstractions; using System; +using System.Text; namespace LLama.Common { @@ -111,34 +112,41 @@ namespace LLama.Common /// public bool MulMatQ { get; set; } - /// - /// - /// - /// The model path. - /// Model context size (n_ctx) - /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) - /// Seed for the random number generator (seed) - /// Whether to use f16 instead of f32 for memory kv (memory_f16) - /// Whether to use mmap for faster loads (use_mmap) - /// Whether to use mlock to keep model in memory (use_mlock) - /// Thether to compute perplexity over the prompt (perplexity) - /// Lora adapter path (lora_adapter) - /// Base model path for the lora adapter (lora_base) - /// Number of threads (-1 = autodetect) (n_threads) - /// Batch size for prompt processing (must be >=32 to use BLAS) (n_batch) - /// Whether to convert eos to newline during the inference. - /// Whether to use embedding mode. (embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore. - /// Grouped-Query Attention - /// RMS Norm Epsilon - /// RoPE base frequency. - /// RoPE frequency scaling factor - /// Use experimental mul_mat_q kernels - public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20, - int seed = 1337, bool useFp16Memory = true, - bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false, - string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512, - bool convertEosToNewLine = false, bool embeddingMode = false, - int gqa = 1, float rmsNormEps = 5e-6f, float rope_freq_base = 10000.0f, float rope_freq_scale = 1f, bool muMatQ = false) + /// + /// The encoding to use to convert text for the model + /// + public string Encoding { get; set; } = "UTF-8"; + + /// + /// + /// + /// The model path. + /// Model context size (n_ctx) + /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) + /// Seed for the random number generator (seed) + /// Whether to use f16 instead of f32 for memory kv (memory_f16) + /// Whether to use mmap for faster loads (use_mmap) + /// Whether to use mlock to keep model in memory (use_mlock) + /// Thether to compute perplexity over the prompt (perplexity) + /// Lora adapter path (lora_adapter) + /// Base model path for the lora adapter (lora_base) + /// Number of threads (-1 = autodetect) (n_threads) + /// Batch size for prompt processing (must be >=32 to use BLAS) (n_batch) + /// Whether to convert eos to newline during the inference. + /// Whether to use embedding mode. (embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore. + /// Grouped-Query Attention + /// RMS Norm Epsilon + /// RoPE base frequency. + /// RoPE frequency scaling factor + /// Use experimental mul_mat_q kernels + /// The encoding to use to convert text for the model + public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20, + int seed = 1337, bool useFp16Memory = true, + bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false, + string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512, + bool convertEosToNewLine = false, bool embeddingMode = false, + int groupedQueryAttention = 1, float rmsNormEps = 5e-6f, float ropeFreqBase = 10000.0f, float ropeFreqScale = 1f, bool muMatQ = false, + string encoding = "UTF-8") { ContextSize = contextSize; GpuLayerCount = gpuLayerCount; @@ -154,11 +162,12 @@ namespace LLama.Common BatchSize = batchSize; ConvertEosToNewLine = convertEosToNewLine; EmbeddingMode = embeddingMode; - GroupedQueryAttention = gqa; + GroupedQueryAttention = groupedQueryAttention; RmsNormEpsilon = rmsNormEps; - RopeFrequencyBase = rope_freq_base; - RopeFrequencyScale = rope_freq_scale; + RopeFrequencyBase = ropeFreqBase; + RopeFrequencyScale = ropeFreqScale; MulMatQ = muMatQ; - } + Encoding = encoding; + } } } diff --git a/LLama/Extensions/ListExtensions.cs b/LLama/Extensions/ListExtensions.cs new file mode 100644 index 00000000..c78d311c --- /dev/null +++ b/LLama/Extensions/ListExtensions.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; + +namespace LLama.Extensions +{ + internal static class ListExtensions + { + public static void AddRangeSpan(this List list, ReadOnlySpan span) + { + for (var i = 0; i < span.Length; i++) + list.Add(span[i]); + } + } +} diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 9c053d37..9501d570 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -1,6 +1,7 @@ using LLama.Exceptions; using LLama.Native; using System; +using System.Buffers; using System.Collections.Generic; using System.Linq; using System.Text; @@ -9,7 +10,6 @@ using System.IO.MemoryMappedFiles; using LLama.Common; using System.Runtime.InteropServices; using LLama.Extensions; -using Microsoft.Win32.SafeHandles; using LLama.Abstractions; namespace LLama @@ -61,26 +61,25 @@ namespace LLama /// /// /// Model params. - /// Encoding to deal with text input. /// The logger. [Obsolete("Use the LLamaWeights.CreateContext instead")] - public LLamaContext(IModelParams @params, string encoding = "UTF-8", ILLamaLogger? logger = null) + public LLamaContext(IModelParams @params, ILLamaLogger? logger = null) { Params = @params; _logger = logger; - _encoding = Encoding.GetEncoding(encoding); + _encoding = Encoding.GetEncoding(@params.Encoding); _logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); _ctx = Utils.InitLLamaContextFromModelParams(Params); } - internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) + internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, ILLamaLogger? logger = null) { Params = @params; _logger = logger; - _encoding = encoding; + _encoding = Encoding.GetEncoding(@params.Encoding); _ctx = nativeContext; } @@ -89,10 +88,9 @@ namespace LLama /// /// /// - /// /// /// - public LLamaContext(LLamaWeights model, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) + public LLamaContext(LLamaWeights model, IModelParams @params, ILLamaLogger? logger = null) { if (model.NativeHandle.IsClosed) throw new ObjectDisposedException("Cannot create context, model weights have been disposed"); @@ -100,7 +98,7 @@ namespace LLama Params = @params; _logger = logger; - _encoding = encoding; + _encoding = Encoding.GetEncoding(@params.Encoding); using var pin = @params.ToLlamaContextParams(out var lparams); _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); @@ -115,7 +113,7 @@ namespace LLama using var pin = Params.ToLlamaContextParams(out var lparams); // Create a blank new context for the model - var ctx = new LLamaContext(SafeLLamaContextHandle.Create(NativeHandle.ModelHandle, lparams), Params, _encoding); + var ctx = new LLamaContext(SafeLLamaContextHandle.Create(NativeHandle.ModelHandle, lparams), Params); // Copy across the state using var state = GetState(); @@ -398,6 +396,7 @@ namespace LLama return candidates_p; } + #region eval overloads /// /// /// @@ -405,18 +404,72 @@ namespace LLama /// /// The updated `pastTokensCount`. /// - public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount) + public int Eval(llama_token[] tokens, llama_token pastTokensCount) { - int total = tokens.Length; - for(int i = 0; i < total; i += Params.BatchSize) + return Eval(tokens.AsSpan(), pastTokensCount); + } + + /// + /// + /// + /// + /// + /// The updated `pastTokensCount`. + /// + public int Eval(List tokens, llama_token pastTokensCount) + { +#if NET5_0_OR_GREATER + var span = CollectionsMarshal.AsSpan(tokens); + return Eval(span, pastTokensCount); +#else + // on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of + // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't + // avoid the copying. + + var rented = ArrayPool.Shared.Rent(tokens.Count); + try + { + tokens.CopyTo(rented, 0); + return Eval(rented, pastTokensCount); + } + finally + { + ArrayPool.Shared.Return(rented); + } +#endif + } + + /// + /// + /// + /// + /// + /// The updated `pastTokensCount`. + /// + public int Eval(ReadOnlyMemory tokens, llama_token pastTokensCount) + { + return Eval(tokens.Span, pastTokensCount); + } + + /// + /// + /// + /// + /// + /// The updated `pastTokensCount`. + /// + public int Eval(ReadOnlySpan tokens, llama_token pastTokensCount) + { + var total = tokens.Length; + for(var i = 0; i < total; i += Params.BatchSize) { - int n_eval = total - i; - if(n_eval > Params.BatchSize) + var n_eval = total - i; + if (n_eval > Params.BatchSize) { n_eval = Params.BatchSize; } - if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads)) + if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads)) { _logger?.Log(nameof(LLamaContext), "Failed to eval.", ILLamaLogger.LogLevel.Error); throw new RuntimeError("Failed to eval."); @@ -426,6 +479,7 @@ namespace LLama } return pastTokensCount; } +#endregion internal IEnumerable GenerateResult(IEnumerable ids) { @@ -433,6 +487,16 @@ namespace LLama yield return _ctx.TokenToString(id, _encoding); } + /// + /// Convert a token into a string + /// + /// + /// + public string TokenToString(llama_token token) + { + return NativeHandle.TokenToString(token, Encoding); + } + /// public virtual void Dispose() { diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index de708785..3d5a3356 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -189,7 +189,7 @@ namespace LLama } TryReuseMathingPrefix(); - _pastTokensCount = Context.Eval(_embeds.ToArray(), _pastTokensCount); + _pastTokensCount = Context.Eval(_embeds, _pastTokensCount); if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index e65c6f19..a738e981 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -178,7 +178,7 @@ namespace LLama } TryReuseMathingPrefix(); - _pastTokensCount = Context.Eval(_embeds.ToArray(), _pastTokensCount); + _pastTokensCount = Context.Eval(_embeds, _pastTokensCount); if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index c2fe4985..446571b0 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -1,125 +1,159 @@ using LLama.Abstractions; using LLama.Common; -using LLama.Native; using System; using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; +using System.Text; using System.Threading; namespace LLama { using llama_token = Int32; + /// /// This executor infer the input as one-time job. Previous inputs won't impact on the /// response to current input. /// - public class StatelessExecutor : ILLamaExecutor + public class StatelessExecutor + : ILLamaExecutor { - private LLamaContext _context; - private LLamaContext.State _originalState; + private readonly LLamaWeights _weights; + private readonly IModelParams _params; + /// /// The context used by the executor when running the inference. /// - public LLamaContext Context => _context; + public LLamaContext Context { get; private set; } + + /// + /// Create a new stateless executor which will use the given model + /// + /// + /// + public StatelessExecutor(LLamaWeights weights, IModelParams @params) + { + _weights = weights; + _params = @params; + + Context = _weights.CreateContext(_params); + Context.Dispose(); + } + /// - /// + /// Create a new stateless executor which will use the model used to create the given context /// - /// The LLama model. + /// + [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")] public StatelessExecutor(LLamaContext context) { - _context = context; - - var tokens = context.Tokenize(" ", true).ToArray(); - _context.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _context.Params.Threads); - _originalState = context.GetState(); + _weights = new LLamaWeights(context.NativeHandle.ModelHandle, Encoding.GetEncoding(context.Params.Encoding)); + _params = context.Params; + + Context = _weights.CreateContext(_params); + Context.Dispose(); } /// public IEnumerable Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { - cancellationToken.ThrowIfCancellationRequested(); - int n_past = 1; - if(inferenceParams is null) - { - inferenceParams = new InferenceParams(); - } - List lastTokens = new(inferenceParams.RepeatLastTokensCount); - for(int i = 0; i < lastTokens.Count; i++) + using var context = _weights.CreateContext(_params); + Context = context; + + if (!Context.NativeHandle.IsClosed) + Context.Dispose(); + Context = _weights.CreateContext(Context.Params); + + if (inferenceParams != null) { - lastTokens[i] = 0; + if (inferenceParams.TokensKeep > Context.ContextSize) + throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})"); } - List tokens = _context.Tokenize(text, true).ToList(); - int n_prompt_tokens = tokens.Count; - _context.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _context.Params.Threads); + cancellationToken.ThrowIfCancellationRequested(); + + var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty(); + var n_past = 1; + inferenceParams ??= new InferenceParams(); + + var lastTokens = new List(inferenceParams.RepeatLastTokensCount); + for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++) + lastTokens.Add(0); + + var tokens = Context.Tokenize(text).ToList(); + var n_prompt_tokens = tokens.Count; + + Context.Eval(tokens, n_past); lastTokens.AddRange(tokens); n_past += n_prompt_tokens; var mu = (float?)null; - int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; - for(int i = 0; i < max_tokens; i++) + var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; + for(var i = 0; i < max_tokens; i++) { if (cancellationToken.IsCancellationRequested) - { - _context.LoadState(_originalState); break; - } - var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? _context.ContextSize : inferenceParams.RepeatLastTokensCount; - var tokenDataArray = _context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, + var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount; + + var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - var id = _context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, + var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar); lastTokens.Add(id); - string response = _context.NativeHandle.TokenToString(id, _context.Encoding); + var response = Context.TokenToString(id); yield return response; tokens.Clear(); tokens.Add(id); - if (inferenceParams.AntiPrompts is not null && inferenceParams.AntiPrompts.Count() > 0) - { - string last_output = ""; - foreach (var token in lastTokens) - { - last_output += _context.NativeHandle.TokenToString(token, _context.Encoding); - } - - bool should_break = false; - foreach (var antiprompt in inferenceParams.AntiPrompts) - { - if (last_output.EndsWith(antiprompt)) - { - should_break = true; - break; - } - } - if (should_break) - { - break; - } - } + if (EndsWithAntiprompt(lastTokens, antiprompts)) + break; // when run out of context - if (n_past + tokens.Count > _context.ContextSize) + // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433 + if (n_past + tokens.Count > Context.ContextSize) { - int n_left = n_past - inferenceParams.TokensKeep; + var n_left = n_past - inferenceParams.TokensKeep; n_past = Math.Max(1, inferenceParams.TokensKeep); - // insert n_left/2 tokens at the start of embed from last_n_tokens - tokens.InsertRange(0, lastTokens.Take(lastTokens.Count - tokens.Count).Skip(_context.ContextSize - n_left / 2 - tokens.Count)); + tokens.Clear(); + tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2)); } - n_past = _context.Eval(tokens.ToArray(), n_past); + n_past = Context.Eval(tokens, n_past); + } + } + + /// + /// Check if the given tokens list ends with any of the antiprompts + /// + /// + /// + /// + private bool EndsWithAntiprompt(IReadOnlyList tokens, IReadOnlyList antiprompts) + { + if (antiprompts.Count == 0 || tokens.Count == 0) + return false; + + var builder = new StringBuilder(); + foreach (var token in tokens) + builder.Append(Context.TokenToString(token)); + + var last_output = builder.ToString(); + + foreach (var antiprompt in antiprompts) + { + if (last_output.EndsWith(antiprompt)) + return true; } - _context.LoadState(_originalState); + return false; } /// diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index cb237a70..7d0ba1b0 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -20,9 +20,15 @@ namespace LLama /// Be careful how you use this! public SafeLlamaModelHandle NativeHandle => _weights; - private LLamaWeights(SafeLlamaModelHandle weights) + /// + /// Encoding to use to convert text into bytes for the model + /// + public Encoding Encoding { get; } + + internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding) { _weights = weights; + Encoding = encoding; } /// @@ -38,7 +44,7 @@ namespace LLama if (!string.IsNullOrEmpty(@params.LoraAdapter)) weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); - return new LLamaWeights(weights); + return new LLamaWeights(weights, Encoding.GetEncoding(@params.Encoding)); } /// @@ -51,11 +57,10 @@ namespace LLama /// Create a llama_context using this model /// /// - /// /// - public LLamaContext CreateContext(IModelParams @params, Encoding encoding) + public LLamaContext CreateContext(IModelParams @params) { - return new LLamaContext(this, @params, encoding); + return new LLamaContext(this, @params); } } } diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 04663d77..2e499196 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -138,7 +138,6 @@ namespace LLama.Native /// Rows: n_tokens
/// Cols: n_vocab /// - /// /// public Span GetLogits() { @@ -179,12 +178,14 @@ namespace LLama.Native /// the number of tokens to use from previous eval calls /// /// Returns true on success - public bool Eval(Memory tokens, int n_past, int n_threads) + public bool Eval(ReadOnlySpan tokens, int n_past, int n_threads) { - using var pin = tokens.Pin(); unsafe { - return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0; + fixed (int* pinned = tokens) + { + return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0; + } } } } diff --git a/LLama/Utils.cs b/LLama/Utils.cs index 27eab2c6..45acad76 100644 --- a/LLama/Utils.cs +++ b/LLama/Utils.cs @@ -37,7 +37,7 @@ namespace LLama [Obsolete("Use SafeLLamaContextHandle Eval method instead")] public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads) { - var slice = tokens.AsMemory().Slice(startIndex, n_tokens); + var slice = tokens.AsSpan().Slice(startIndex, n_tokens); return ctx.Eval(slice, n_past, n_threads) ? 0 : 1; }