From 02a46fc3639cd0eb989425d9dc7db806dd0c20f5 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Thu, 17 Aug 2023 23:26:20 +0100 Subject: [PATCH 01/10] Updated demos to use the new loading/multi context system --- .../NewVersion/ChatSessionStripRoleName.cs | 17 ++++++++++------- .../NewVersion/ChatSessionWithRoleName.cs | 15 ++++++++------- LLama.Examples/NewVersion/GetEmbeddings.cs | 7 +------ .../NewVersion/InstructModeExecute.cs | 13 ++++++------- .../NewVersion/InteractiveModeExecute.cs | 15 +++++++-------- LLama.Examples/NewVersion/LoadAndSaveSession.cs | 16 ++++++++-------- LLama.Examples/NewVersion/LoadAndSaveState.cs | 17 ++++++++--------- LLama.Examples/NewVersion/QuantizeModel.cs | 14 +++++--------- .../NewVersion/StatelessModeExecute.cs | 11 +++++------ LLama.Examples/NewVersion/TestRunner.cs | 8 +------- 10 files changed, 59 insertions(+), 74 deletions(-) diff --git a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs index 6402e360..230118e5 100644 --- a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs +++ b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs @@ -1,9 +1,5 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,15 +8,22 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); - ChatSession session = new ChatSession(ex).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); + + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 
1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters, Encoding.UTF8); + var executor = new InteractiveExecutor(context); + + var session = new ChatSession(executor).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The chat session has started. The role names won't be printed."); Console.ForegroundColor = ConsoleColor.White; + // show the prompt + Console.Write(prompt); while (true) { foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" } })) diff --git a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs index d1cbf34b..a3609388 100644 --- a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs +++ b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs @@ -1,9 +1,5 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,10 +8,15 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); - ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream. 
+ + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters, Encoding.UTF8); + var executor = new InteractiveExecutor(context); + + var session = new ChatSession(executor); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result."); diff --git a/LLama.Examples/NewVersion/GetEmbeddings.cs b/LLama.Examples/NewVersion/GetEmbeddings.cs index ed12f868..516d2da7 100644 --- a/LLama.Examples/NewVersion/GetEmbeddings.cs +++ b/LLama.Examples/NewVersion/GetEmbeddings.cs @@ -1,9 +1,4 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,7 +7,7 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var embedder = new LLamaEmbedder(new ModelParams(modelPath)); while (true) diff --git a/LLama.Examples/NewVersion/InstructModeExecute.cs b/LLama.Examples/NewVersion/InstructModeExecute.cs index f81f2f58..0a384062 100644 --- a/LLama.Examples/NewVersion/InstructModeExecute.cs +++ b/LLama.Examples/NewVersion/InstructModeExecute.cs @@ -1,9 +1,5 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,10 +8,13 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/dan.txt").Trim(); - InstructExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 
1024))); + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters, Encoding.UTF8); + var executor = new InstructExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions. For example, you can input \"Write a story about a fox who want to " + @@ -26,7 +25,7 @@ namespace LLama.Examples.NewVersion while (true) { - foreach (var text in ex.Infer(prompt, inferenceParams)) + foreach (var text in executor.Infer(prompt, inferenceParams)) { Console.Write(text); } diff --git a/LLama.Examples/NewVersion/InteractiveModeExecute.cs b/LLama.Examples/NewVersion/InteractiveModeExecute.cs index aaacabbe..9fee007f 100644 --- a/LLama.Examples/NewVersion/InteractiveModeExecute.cs +++ b/LLama.Examples/NewVersion/InteractiveModeExecute.cs @@ -1,21 +1,20 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { public class InteractiveModeExecute { - public async static Task Run() + public static async Task Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); - var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); + var modelPath = Console.ReadLine(); + var prompt = (await File.ReadAllTextAsync("Assets/chat-with-bob.txt")).Trim(); - InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 256))); + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters, Encoding.UTF8); + var ex = new InteractiveExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow; 
Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 128 and the context size is 256. (an example for small scale usage)"); diff --git a/LLama.Examples/NewVersion/LoadAndSaveSession.cs b/LLama.Examples/NewVersion/LoadAndSaveSession.cs index cbed9179..5e5c4252 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveSession.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveSession.cs @@ -1,10 +1,5 @@ using LLama.Common; -using LLama.OldVersion; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -13,10 +8,15 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5))); - ChatSession session = new ChatSession(ex); // The only change is to remove the transform for the output text stream. + + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters, Encoding.UTF8); + var ex = new InteractiveExecutor(context); + + var session = new ChatSession(ex); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The chat session has started. In this example, the prompt is printed for better visual result. 
Input \"save\" to save and reload the session."); diff --git a/LLama.Examples/NewVersion/LoadAndSaveState.cs b/LLama.Examples/NewVersion/LoadAndSaveState.cs index 15f2f815..1a1c0d88 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveState.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveState.cs @@ -1,9 +1,5 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,10 +8,13 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); - InteractiveExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 256))); + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters, Encoding.UTF8); + var ex = new InteractiveExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to 64 and the context size is 256. 
(an example for small scale usage)"); @@ -47,9 +46,9 @@ namespace LLama.Examples.NewVersion Console.WriteLine("All states saved!"); Console.ForegroundColor = ConsoleColor.White; - var model = ex.Context; - model.LoadState(modelStatePath); - ex = new InteractiveExecutor(model); + var ctx = ex.Context; + ctx.LoadState(modelStatePath); + ex = new InteractiveExecutor(ctx); ex.LoadState(executorStatePath); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Loaded state!"); diff --git a/LLama.Examples/NewVersion/QuantizeModel.cs b/LLama.Examples/NewVersion/QuantizeModel.cs index a5ad81d8..71966af8 100644 --- a/LLama.Examples/NewVersion/QuantizeModel.cs +++ b/LLama.Examples/NewVersion/QuantizeModel.cs @@ -1,11 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading; -using System.Threading.Tasks; - -namespace LLama.Examples.NewVersion +namespace LLama.Examples.NewVersion { public class QuantizeModel { @@ -13,13 +6,16 @@ namespace LLama.Examples.NewVersion { Console.Write("Please input your original model path: "); var inputPath = Console.ReadLine(); + Console.Write("Please input your output model path: "); var outputPath = Console.ReadLine(); + Console.Write("Please input the quantize type (one of q4_0, q4_1, q5_0, q5_1, q8_0): "); var quantizeType = Console.ReadLine(); + if (LLamaQuantizer.Quantize(inputPath, outputPath, quantizeType)) { - Console.WriteLine("Quantization succeed!"); + Console.WriteLine("Quantization succeeded!"); } else { diff --git a/LLama.Examples/NewVersion/StatelessModeExecute.cs b/LLama.Examples/NewVersion/StatelessModeExecute.cs index 8ff2c0a1..dadaf70a 100644 --- a/LLama.Examples/NewVersion/StatelessModeExecute.cs +++ b/LLama.Examples/NewVersion/StatelessModeExecute.cs @@ -1,9 +1,5 @@ using LLama.Common; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; namespace LLama.Examples.NewVersion { @@ -12,9 
+8,12 @@ namespace LLama.Examples.NewVersion public static void Run() { Console.Write("Please input your model path: "); - string modelPath = Console.ReadLine(); + var modelPath = Console.ReadLine(); - StatelessExecutor ex = new(new LLamaContext(new ModelParams(modelPath, contextSize: 256))); + var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); + using var model = LLamaWeights.LoadFromFile(parameters); + using var context = model.CreateContext(parameters, Encoding.UTF8); + var ex = new StatelessExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the inference is an one-time job. That says, the previous input and response has " + diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs index c90bc78d..6cc3f3da 100644 --- a/LLama.Examples/NewVersion/TestRunner.cs +++ b/LLama.Examples/NewVersion/TestRunner.cs @@ -1,10 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace LLama.Examples.NewVersion +namespace LLama.Examples.NewVersion { public class NewVersionTestRunner { From 592a80840b4e8039a03856dd32d785daa9554e28 Mon Sep 17 00:00:00 2001 From: Erin Loy Date: Sat, 19 Aug 2023 15:55:19 -0700 Subject: [PATCH 02/10] renamed some arguments in ModelParams constructor so that classcan be serialized easily --- LLama/Common/ModelParams.cs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 5cb81078..2230c70c 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -128,17 +128,17 @@ namespace LLama.Common /// Batch size for prompt processing (must be >=32 to use BLAS) (n_batch) /// Whether to convert eos to newline during the inference. /// Whether to use embedding mode. 
(embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore. - /// Grouped-Query Attention - /// RMS Norm Epsilon - /// RoPE base frequency. - /// RoPE frequency scaling factor - /// Use experimental mul_mat_q kernels + /// Grouped-Query Attention + /// RMS Norm Epsilon + /// RoPE base frequency. + /// RoPE frequency scaling factor + /// Use experimental mul_mat_q kernels public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20, int seed = 1337, bool useFp16Memory = true, bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false, string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512, bool convertEosToNewLine = false, bool embeddingMode = false, - int gqa = 1, float rmsNormEps = 5e-6f, float rope_freq_base = 10000.0f, float rope_freq_scale = 1f, bool muMatQ = false) + int groupedQueryAttention = 1, float rmsNormEpsilon = 5e-6f, float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false) { ContextSize = contextSize; GpuLayerCount = gpuLayerCount; @@ -154,11 +154,11 @@ namespace LLama.Common BatchSize = batchSize; ConvertEosToNewLine = convertEosToNewLine; EmbeddingMode = embeddingMode; - GroupedQueryAttention = gqa; - RmsNormEpsilon = rmsNormEps; - RopeFrequencyBase = rope_freq_base; - RopeFrequencyScale = rope_freq_scale; - MulMatQ = muMatQ; + GroupedQueryAttention = groupedQueryAttention; + RmsNormEpsilon = rmsNormEpsilon; + RopeFrequencyBase = ropeFrequencyBase; + RopeFrequencyScale = ropeFrequencyScale; + MulMatQ = mulMatQ; } } } From 4d0c044b9f537d80c7950bb0de738a34d709889b Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 18 Aug 2023 00:37:30 +0100 Subject: [PATCH 03/10] Added tests for the StatelessExecutor, one is currently failing --- LLama.Unittest/StatelessExecutorTest.cs | 59 +++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 LLama.Unittest/StatelessExecutorTest.cs diff --git 
a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs new file mode 100644 index 00000000..fe3e6e03 --- /dev/null +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -0,0 +1,59 @@ +using LLama.Common; +using System.Text; + +namespace LLama.Unittest +{ + public class StatelessExecutorTest + : IDisposable + { + private readonly LLamaWeights _weights; + private readonly ModelParams _params; + + public StatelessExecutorTest() + { + _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin") + { + ContextSize = 64, + Seed = 1754 + }; + _weights = LLamaWeights.LoadFromFile(_params); + } + + public void Dispose() + { + _weights.Dispose(); + } + + [Fact] + public void Stateless() + { + var executor = new StatelessExecutor(_weights.CreateContext(_params, Encoding.UTF8)); + + const string question = "Question. what is a cat?\nAnswer: "; + const string expected = " a domestic or wild animal that is typically small to medium-sized, has fur, four legs, and sharp retractable claws."; + var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } }; + + var result1 = string.Join("", executor.Infer(question, @params)); + Assert.Equal(expected, result1); + + var result2 = string.Join("", executor.Infer(question, @params)); + Assert.Equal(expected, result2); + + Assert.Equal(result1, result2); + } + + [Fact] + public void OutOfContext() + { + var executor = new StatelessExecutor(_weights.CreateContext(_params, Encoding.UTF8)); + + const string question = "Question. 
why is a cat the best pet?\nAnswer: "; + var @params = new InferenceParams() + { + MaxTokens = 128, + }; + + var result1 = string.Join("", executor.Infer(question, @params)); + } + } +} \ No newline at end of file From ae8ef17a4a30659fe621de79305c9a8bab17857c Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 18 Aug 2023 00:39:25 +0100 Subject: [PATCH 04/10] - Added various convenience overloads to `LLamaContext.Eval` - Converted `SafeLLamaContextHandle` to take a `ReadOnlySpan` for Eval, narrower type better represents what's really needed --- .../NewVersion/StatelessModeExecute.cs | 2 +- LLama/LLamaContext.cs | 71 ++++++++++++++++++- LLama/Native/SafeLLamaContextHandle.cs | 8 ++- LLama/Utils.cs | 2 +- 4 files changed, 76 insertions(+), 7 deletions(-) diff --git a/LLama.Examples/NewVersion/StatelessModeExecute.cs b/LLama.Examples/NewVersion/StatelessModeExecute.cs index dadaf70a..d43cdd79 100644 --- a/LLama.Examples/NewVersion/StatelessModeExecute.cs +++ b/LLama.Examples/NewVersion/StatelessModeExecute.cs @@ -29,7 +29,7 @@ namespace LLama.Examples.NewVersion Console.Write("\nQuestion: "); Console.ForegroundColor = ConsoleColor.Green; string prompt = Console.ReadLine(); - Console.ForegroundColor = ConsoleColor.White; + Console.ForegroundColor = ConsoleColor.White; Console.Write("Answer: "); prompt = $"Question: {prompt.Trim()} Answer: "; foreach (var text in ex.Infer(prompt, inferenceParams)) diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 1ef2a8db..4fb601d4 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -1,6 +1,7 @@ using LLama.Exceptions; using LLama.Native; using System; +using System.Buffers; using System.Collections.Generic; using System.Linq; using System.Text; @@ -384,6 +385,7 @@ namespace LLama return candidates_p; } + #region eval overloads /// /// /// @@ -391,7 +393,61 @@ namespace LLama /// /// The updated `pastTokensCount`. 
/// - public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount) + public int Eval(llama_token[] tokens, llama_token pastTokensCount) + { + return Eval(tokens.AsSpan(), pastTokensCount); + } + + /// + /// + /// + /// + /// + /// The updated `pastTokensCount`. + /// + public int Eval(List tokens, llama_token pastTokensCount) + { +#if NET5_0_OR_GREATER + var span = CollectionsMarshal.AsSpan(tokens); + return Eval(span, pastTokensCount); +#else + // on netstandard2.0 we can't use collections marshal to get directly at the internal memory of + // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't + // avoid the copying. + + var rented = ArrayPool.Shared.Rent(tokens.Count); + try + { + tokens.CopyTo(rented, 0); + return Eval(rented, pastTokensCount); + } + finally + { + ArrayPool.Shared.Return(rented); + } +#endif + } + + /// + /// + /// + /// + /// + /// The updated `pastTokensCount`. + /// + public int Eval(ReadOnlyMemory tokens, llama_token pastTokensCount) + { + return Eval(tokens.Span, pastTokensCount); + } + + /// + /// + /// + /// + /// + /// The updated `pastTokensCount`. 
+ /// + public int Eval(ReadOnlySpan tokens, llama_token pastTokensCount) { int total = tokens.Length; for(int i = 0; i < total; i += Params.BatchSize) @@ -402,7 +458,7 @@ namespace LLama n_eval = Params.BatchSize; } - if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads)) + if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads)) { _logger?.Log(nameof(LLamaContext), "Failed to eval.", ILLamaLogger.LogLevel.Error); throw new RuntimeError("Failed to eval."); @@ -412,6 +468,7 @@ namespace LLama } return pastTokensCount; } +#endregion internal IEnumerable GenerateResult(IEnumerable ids) { @@ -419,6 +476,16 @@ namespace LLama yield return _ctx.TokenToString(id, _encoding); } + /// + /// Convert a token into a string + /// + /// + /// + public string TokenToString(llama_token token) + { + return NativeHandle.TokenToString(token, Encoding); + } + /// public virtual void Dispose() { diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 04663d77..bb5911fe 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -179,12 +179,14 @@ namespace LLama.Native /// the number of tokens to use from previous eval calls /// /// Returns true on success - public bool Eval(Memory tokens, int n_past, int n_threads) + public bool Eval(ReadOnlySpan tokens, int n_past, int n_threads) { - using var pin = tokens.Pin(); unsafe { - return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0; + fixed (int* pinned = tokens) + { + return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0; + } } } } diff --git a/LLama/Utils.cs b/LLama/Utils.cs index 27eab2c6..45acad76 100644 --- a/LLama/Utils.cs +++ b/LLama/Utils.cs @@ -37,7 +37,7 @@ namespace LLama [Obsolete("Use SafeLLamaContextHandle Eval method instead")] public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int 
startIndex, int n_tokens, int n_past, int n_threads) { - var slice = tokens.AsMemory().Slice(startIndex, n_tokens); + var slice = tokens.AsSpan().Slice(startIndex, n_tokens); return ctx.Eval(slice, n_past, n_threads) ? 0 : 1; } From 4738c26299515aa837b1b11755785cc0058de285 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 18 Aug 2023 01:29:57 +0100 Subject: [PATCH 05/10] - Reduced context size of test, to speed it up - Removed some unnecessary `ToArray` calls - Initial pass on LLamaStatelessExecutor, the context overflow management is broken but I think I found where it's ported from --- LLama.Unittest/StatelessExecutorTest.cs | 11 ++- LLama/Extensions/ListExtensions.cs | 14 ++++ LLama/LLamaContext.cs | 10 +-- LLama/LLamaInstructExecutor.cs | 2 +- LLama/LLamaInteractExecutor.cs | 2 +- LLama/LLamaStatelessExecutor.cs | 95 ++++++++++++++----------- 6 files changed, 81 insertions(+), 53 deletions(-) create mode 100644 LLama/Extensions/ListExtensions.cs diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs index fe3e6e03..0af49b15 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -13,7 +13,7 @@ namespace LLama.Unittest { _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin") { - ContextSize = 64, + ContextSize = 40, Seed = 1754 }; _weights = LLamaWeights.LoadFromFile(_params); @@ -48,12 +48,17 @@ namespace LLama.Unittest var executor = new StatelessExecutor(_weights.CreateContext(_params, Encoding.UTF8)); const string question = "Question. 
why is a cat the best pet?\nAnswer: "; + const string answer = ""; + var @params = new InferenceParams() { - MaxTokens = 128, + MaxTokens = 50, + TokensKeep = question.Length, }; - var result1 = string.Join("", executor.Infer(question, @params)); + var result = string.Join("", executor.Infer(question, @params)); + + Assert.Equal(answer, result); } } } \ No newline at end of file diff --git a/LLama/Extensions/ListExtensions.cs b/LLama/Extensions/ListExtensions.cs new file mode 100644 index 00000000..c78d311c --- /dev/null +++ b/LLama/Extensions/ListExtensions.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; + +namespace LLama.Extensions +{ + internal static class ListExtensions + { + public static void AddRangeSpan(this List list, ReadOnlySpan span) + { + for (var i = 0; i < span.Length; i++) + list.Add(span[i]); + } + } +} diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 4fb601d4..75e30282 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -411,7 +411,7 @@ namespace LLama var span = CollectionsMarshal.AsSpan(tokens); return Eval(span, pastTokensCount); #else - // on netstandard2.0 we can't use collections marshal to get directly at the internal memory of + // on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't // avoid the copying. 
@@ -449,11 +449,11 @@ namespace LLama /// public int Eval(ReadOnlySpan tokens, llama_token pastTokensCount) { - int total = tokens.Length; - for(int i = 0; i < total; i += Params.BatchSize) + var total = tokens.Length; + for(var i = 0; i < total; i += Params.BatchSize) { - int n_eval = total - i; - if(n_eval > Params.BatchSize) + var n_eval = total - i; + if (n_eval > Params.BatchSize) { n_eval = Params.BatchSize; } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 6773cdde..d1396853 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -189,7 +189,7 @@ namespace LLama } TryReuseMathingPrefix(); - _pastTokensCount = Context.Eval(_embeds.ToArray(), _pastTokensCount); + _pastTokensCount = Context.Eval(_embeds, _pastTokensCount); if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 533a1863..efaeeac9 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -178,7 +178,7 @@ namespace LLama } TryReuseMathingPrefix(); - _pastTokensCount = Context.Eval(_embeds.ToArray(), _pastTokensCount); + _pastTokensCount = Context.Eval(_embeds, _pastTokensCount); if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession)) { diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index f09ff7dd..f78ac0da 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -1,10 +1,10 @@ using LLama.Abstractions; using LLama.Common; -using LLama.Native; using System; using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; +using System.Text; using System.Threading; namespace LLama @@ -16,12 +16,14 @@ namespace LLama /// public class StatelessExecutor : ILLamaExecutor { - private LLamaContext _context; - private LLamaContext.State _originalState; + private readonly LLamaContext _context; + private readonly LLamaContext.State 
_originalState; + /// /// The context used by the executor when running the inference. /// public LLamaContext Context => _context; + /// /// /// @@ -31,7 +33,7 @@ namespace LLama _context = context; var tokens = context.Tokenize(" ", true).ToArray(); - _context.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _context.Params.Threads); + _context.NativeHandle.Eval(tokens.AsSpan(0, tokens.Length), 0, _context.Params.Threads); _originalState = context.GetState(); } @@ -39,27 +41,26 @@ namespace LLama public IEnumerable Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); - int n_past = 1; - if(inferenceParams is null) - { - inferenceParams = new InferenceParams(); - } - List lastTokens = new(inferenceParams.RepeatLastTokensCount); - for(int i = 0; i < lastTokens.Count; i++) - { - lastTokens[i] = 0; - } - List tokens = _context.Tokenize(text, true).ToList(); - int n_prompt_tokens = tokens.Count; - _context.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _context.Params.Threads); + var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty(); + var n_past = 1; + inferenceParams ??= new InferenceParams(); + + var lastTokens = new List(inferenceParams.RepeatLastTokensCount); + for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++) + lastTokens.Add(0); + + var tokens = _context.Tokenize(text).ToList(); + var n_prompt_tokens = tokens.Count; + + _context.Eval(tokens, n_past); lastTokens.AddRange(tokens); n_past += n_prompt_tokens; var mu = (float?)null; - int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; - for(int i = 0; i < max_tokens; i++) + var max_tokens = inferenceParams.MaxTokens < 0 ? 
int.MaxValue : inferenceParams.MaxTokens; + for(var i = 0; i < max_tokens; i++) { if (cancellationToken.IsCancellationRequested) { @@ -76,35 +77,17 @@ namespace LLama lastTokens.Add(id); - string response = _context.NativeHandle.TokenToString(id, _context.Encoding); + var response = _context.TokenToString(id); yield return response; tokens.Clear(); tokens.Add(id); - if (inferenceParams.AntiPrompts is not null && inferenceParams.AntiPrompts.Count() > 0) - { - string last_output = ""; - foreach (var token in lastTokens) - { - last_output += _context.NativeHandle.TokenToString(token, _context.Encoding); - } - - bool should_break = false; - foreach (var antiprompt in inferenceParams.AntiPrompts) - { - if (last_output.EndsWith(antiprompt)) - { - should_break = true; - break; - } - } - if (should_break) - { - break; - } - } + if (EndsWithAntiprompt(lastTokens, antiprompts)) + break; + // todo: this seems to be based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433 + // todo: but it's broken! 
// when run out of context if (n_past + tokens.Count > _context.ContextSize) { @@ -116,12 +99,38 @@ namespace LLama tokens.InsertRange(0, lastTokens.Take(lastTokens.Count - tokens.Count).Skip(_context.ContextSize - n_left / 2 - tokens.Count)); } - n_past = _context.Eval(tokens.ToArray(), n_past); + n_past = _context.Eval(tokens, n_past); } _context.LoadState(_originalState); } + /// + /// Check if the given tokens list ends with any of the antiprompts + /// + /// + /// + /// + private bool EndsWithAntiprompt(IReadOnlyList tokens, IReadOnlyList antiprompts) + { + if (antiprompts.Count == 0 || tokens.Count == 0) + return false; + + var builder = new StringBuilder(); + foreach (var token in tokens) + builder.Append(_context.TokenToString(token)); + + var last_output = builder.ToString(); + + foreach (var antiprompt in antiprompts) + { + if (last_output.EndsWith(antiprompt)) + return true; + } + + return false; + } + /// public async IAsyncEnumerable InferAsync(string text, IInferenceParams? 
inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { From e7b217f4620a62c651918715e66a18c70e78661e Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sat, 19 Aug 2023 17:27:04 +0100 Subject: [PATCH 06/10] Fixed out of context logic --- LLama.Unittest/StatelessExecutorTest.cs | 14 ++++++++++---- LLama/LLamaStatelessExecutor.cs | 19 +++++++++++++------ 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs index 0af49b15..1c69a4e6 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -13,7 +13,7 @@ namespace LLama.Unittest { _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin") { - ContextSize = 40, + ContextSize = 60, Seed = 1754 }; _weights = LLamaWeights.LoadFromFile(_params); @@ -47,12 +47,18 @@ namespace LLama.Unittest { var executor = new StatelessExecutor(_weights.CreateContext(_params, Encoding.UTF8)); - const string question = "Question. why is a cat the best pet?\nAnswer: "; - const string answer = ""; + const string question = " Question. why is a cat the best pet?\nAnswer: "; + const string answer = " there are many reasons why cats make excellent pets! here are just a few of them:\n" + + "1)Loyalty: Cats are known for their loyalty to their owners, and they will often follow " + + "you around the house if you call them. They will always come running when called, and they’ll " + + "nuzzle and purr with delight when you walk into the room! they adore being close to their human " + + "family members and can form very close bonds.\n"; + // The context size is set to 60. 
Generate more than that, forcing it to generate a coherent response + // with a modified context var @params = new InferenceParams() { - MaxTokens = 50, + MaxTokens = 100, TokensKeep = question.Length, }; diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index f78ac0da..50eeee0e 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -10,11 +10,13 @@ using System.Threading; namespace LLama { using llama_token = Int32; + /// /// This executor infer the input as one-time job. Previous inputs won't impact on the /// response to current input. /// - public class StatelessExecutor : ILLamaExecutor + public class StatelessExecutor + : ILLamaExecutor { private readonly LLamaContext _context; private readonly LLamaContext.State _originalState; @@ -40,6 +42,12 @@ namespace LLama /// public IEnumerable Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { + if (inferenceParams != null) + { + if (inferenceParams.TokensKeep > Context.ContextSize) + throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})"); + } + cancellationToken.ThrowIfCancellationRequested(); var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty(); @@ -86,17 +94,16 @@ namespace LLama if (EndsWithAntiprompt(lastTokens, antiprompts)) break; - // todo: this seems to be based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433 - // todo: but it's broken! 
// when run out of context + // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433 if (n_past + tokens.Count > _context.ContextSize) { - int n_left = n_past - inferenceParams.TokensKeep; + var n_left = n_past - inferenceParams.TokensKeep; n_past = Math.Max(1, inferenceParams.TokensKeep); - // insert n_left/2 tokens at the start of embed from last_n_tokens - tokens.InsertRange(0, lastTokens.Take(lastTokens.Count - tokens.Count).Skip(_context.ContextSize - n_left / 2 - tokens.Count)); + tokens.Clear(); + tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2)); } n_past = _context.Eval(tokens, n_past); From 6f2ab8e0391b2b53c6737fa9fab9ce2b31fcda80 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 22 Aug 2023 00:32:36 +0100 Subject: [PATCH 07/10] Not asserting the answer, just that it didn't change --- LLama.Unittest/StatelessExecutorTest.cs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs index 1c69a4e6..b9f89cd6 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -30,15 +30,12 @@ namespace LLama.Unittest var executor = new StatelessExecutor(_weights.CreateContext(_params, Encoding.UTF8)); const string question = "Question. what is a cat?\nAnswer: "; - const string expected = " a domestic or wild animal that is typically small to medium-sized, has fur, four legs, and sharp retractable claws."; var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." 
} }; var result1 = string.Join("", executor.Infer(question, @params)); - Assert.Equal(expected, result1); - var result2 = string.Join("", executor.Infer(question, @params)); - Assert.Equal(expected, result2); + // Check that it produced the exact same result both times Assert.Equal(result1, result2); } From 48bc0a6f8ac28bbd7b95e264c5244b9f463ca9b5 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 22 Aug 2023 00:39:19 +0100 Subject: [PATCH 08/10] Do the same for the second test, hopefully fixing CI --- LLama.Unittest/StatelessExecutorTest.cs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs index b9f89cd6..38098278 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -1,16 +1,19 @@ using LLama.Common; using System.Text; +using Xunit.Abstractions; namespace LLama.Unittest { public class StatelessExecutorTest : IDisposable { + private readonly ITestOutputHelper _testOutputHelper; private readonly LLamaWeights _weights; private readonly ModelParams _params; - public StatelessExecutorTest() + public StatelessExecutorTest(ITestOutputHelper testOutputHelper) { + _testOutputHelper = testOutputHelper; _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin") { ContextSize = 60, @@ -35,6 +38,8 @@ namespace LLama.Unittest var result1 = string.Join("", executor.Infer(question, @params)); var result2 = string.Join("", executor.Infer(question, @params)); + _testOutputHelper.WriteLine(result1); + // Check that it produced the exact same result both times Assert.Equal(result1, result2); } @@ -45,11 +50,6 @@ namespace LLama.Unittest var executor = new StatelessExecutor(_weights.CreateContext(_params, Encoding.UTF8)); const string question = " Question. why is a cat the best pet?\nAnswer: "; - const string answer = " there are many reasons why cats make excellent pets! 
here are just a few of them:\n" + - "1)Loyalty: Cats are known for their loyalty to their owners, and they will often follow " + - "you around the house if you call them. They will always come running when called, and they’ll " + - "nuzzle and purr with delight when you walk into the room! they adore being close to their human " + - "family members and can form very close bonds.\n"; // The context size is set to 60. Generate more than that, forcing it to generate a coherent response // with a modified context @@ -59,9 +59,13 @@ namespace LLama.Unittest TokensKeep = question.Length, }; - var result = string.Join("", executor.Infer(question, @params)); + var result1 = string.Join("", executor.Infer(question, @params)); + var result2 = string.Join("", executor.Infer(question, @params)); + + _testOutputHelper.WriteLine(result1); - Assert.Equal(answer, result); + // Check that it produced the exact same result both times + Assert.Equal(result1, result2); } } } \ No newline at end of file From a9e6f21ab819ef2b44323a3df6bf338924088931 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 22 Aug 2023 01:25:45 +0100 Subject: [PATCH 09/10] - Creating and destroying contexts in the stateless executor, saving memory. It now uses zero memory when not inferring! 
- Passing encoding in the `IModelParams`, which reduces how often encoding needs to be passed around --- .../NewVersion/StatelessModeExecute.cs | 8 +- LLama.Examples/NewVersion/TalkToYourself.cs | 4 +- LLama.Unittest/LLamaContextTests.cs | 2 +- LLama.Unittest/StatelessExecutorTest.cs | 5 +- LLama.Web/Common/ModelOptions.cs | 9 ++- LLama/Abstractions/IModelParams.cs | 9 ++- LLama/Common/ModelParams.cs | 75 +++++++++++-------- LLama/LLamaContext.cs | 17 ++--- LLama/LLamaStatelessExecutor.cs | 66 ++++++++++------ LLama/LLamaWeights.cs | 15 ++-- LLama/Native/SafeLLamaContextHandle.cs | 1 - 11 files changed, 122 insertions(+), 89 deletions(-) diff --git a/LLama.Examples/NewVersion/StatelessModeExecute.cs b/LLama.Examples/NewVersion/StatelessModeExecute.cs index d43cdd79..7f59e73e 100644 --- a/LLama.Examples/NewVersion/StatelessModeExecute.cs +++ b/LLama.Examples/NewVersion/StatelessModeExecute.cs @@ -1,5 +1,4 @@ using LLama.Common; -using System.Text; namespace LLama.Examples.NewVersion { @@ -12,8 +11,7 @@ namespace LLama.Examples.NewVersion var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters, Encoding.UTF8); - var ex = new StatelessExecutor(context); + var ex = new StatelessExecutor(model, parameters); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the inference is an one-time job. 
That says, the previous input and response has " + @@ -28,10 +26,10 @@ namespace LLama.Examples.NewVersion { Console.Write("\nQuestion: "); Console.ForegroundColor = ConsoleColor.Green; - string prompt = Console.ReadLine(); + var prompt = Console.ReadLine(); Console.ForegroundColor = ConsoleColor.White; Console.Write("Answer: "); - prompt = $"Question: {prompt.Trim()} Answer: "; + prompt = $"Question: {prompt?.Trim()} Answer: "; foreach (var text in ex.Infer(prompt, inferenceParams)) { Console.Write(text); diff --git a/LLama.Examples/NewVersion/TalkToYourself.cs b/LLama.Examples/NewVersion/TalkToYourself.cs index 35a65241..309d5654 100644 --- a/LLama.Examples/NewVersion/TalkToYourself.cs +++ b/LLama.Examples/NewVersion/TalkToYourself.cs @@ -20,9 +20,9 @@ namespace LLama.Examples.NewVersion using var weights = LLamaWeights.LoadFromFile(@params); // Create 2 contexts sharing the same weights - using var aliceCtx = weights.CreateContext(@params, Encoding.UTF8); + using var aliceCtx = weights.CreateContext(@params); var alice = new InteractiveExecutor(aliceCtx); - using var bobCtx = weights.CreateContext(@params, Encoding.UTF8); + using var bobCtx = weights.CreateContext(@params); var bob = new InteractiveExecutor(bobCtx); // Initial alice prompt diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs index a34f58cb..1c5fa952 100644 --- a/LLama.Unittest/LLamaContextTests.cs +++ b/LLama.Unittest/LLamaContextTests.cs @@ -16,7 +16,7 @@ namespace LLama.Unittest ContextSize = 768, }; _weights = LLamaWeights.LoadFromFile(@params); - _context = _weights.CreateContext(@params, Encoding.UTF8); + _context = _weights.CreateContext(@params); } public void Dispose() diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs index 38098278..37031da3 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -1,5 +1,4 @@ using LLama.Common; -using System.Text; using 
Xunit.Abstractions; namespace LLama.Unittest @@ -30,7 +29,7 @@ namespace LLama.Unittest [Fact] public void Stateless() { - var executor = new StatelessExecutor(_weights.CreateContext(_params, Encoding.UTF8)); + var executor = new StatelessExecutor(_weights, _params); const string question = "Question. what is a cat?\nAnswer: "; var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } }; @@ -47,7 +46,7 @@ namespace LLama.Unittest [Fact] public void OutOfContext() { - var executor = new StatelessExecutor(_weights.CreateContext(_params, Encoding.UTF8)); + var executor = new StatelessExecutor(_weights, _params); const string question = " Question. why is a cat the best pet?\nAnswer: "; diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index e8b89dee..3f5a3f0c 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -2,7 +2,8 @@ namespace LLama.Web.Common { - public class ModelOptions : IModelParams + public class ModelOptions + : IModelParams { public string Name { get; set; } @@ -111,5 +112,9 @@ namespace LLama.Web.Common /// public bool MulMatQ { get; set; } - } + /// + /// The encoding to use for models + /// + public string Encoding { get; set; } = "UTF-8"; + } } diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs index fdc91152..64a0125b 100644 --- a/LLama/Abstractions/IModelParams.cs +++ b/LLama/Abstractions/IModelParams.cs @@ -1,6 +1,4 @@ -using System; - -namespace LLama.Abstractions +namespace LLama.Abstractions { public interface IModelParams { @@ -119,5 +117,10 @@ namespace LLama.Abstractions /// Use experimental mul_mat_q kernels /// bool MulMatQ { get; set; } + + /// + /// The encoding to use for models + /// + string Encoding { get; set; } } } \ No newline at end of file diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 2230c70c..c0741abe 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs 
@@ -1,5 +1,6 @@ using LLama.Abstractions; using System; +using System.Text; namespace LLama.Common { @@ -111,34 +112,41 @@ namespace LLama.Common /// public bool MulMatQ { get; set; } - /// - /// - /// - /// The model path. - /// Model context size (n_ctx) - /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) - /// Seed for the random number generator (seed) - /// Whether to use f16 instead of f32 for memory kv (memory_f16) - /// Whether to use mmap for faster loads (use_mmap) - /// Whether to use mlock to keep model in memory (use_mlock) - /// Thether to compute perplexity over the prompt (perplexity) - /// Lora adapter path (lora_adapter) - /// Base model path for the lora adapter (lora_base) - /// Number of threads (-1 = autodetect) (n_threads) - /// Batch size for prompt processing (must be >=32 to use BLAS) (n_batch) - /// Whether to convert eos to newline during the inference. - /// Whether to use embedding mode. (embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore. - /// Grouped-Query Attention - /// RMS Norm Epsilon - /// RoPE base frequency. - /// RoPE frequency scaling factor - /// Use experimental mul_mat_q kernels - public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20, - int seed = 1337, bool useFp16Memory = true, - bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false, - string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512, - bool convertEosToNewLine = false, bool embeddingMode = false, - int groupedQueryAttention = 1, float rmsNormEpsilon = 5e-6f, float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false) + /// + /// The encoding to use to convert text for the model + /// + public string Encoding { get; set; } = "UTF-8"; + + /// + /// + /// + /// The model path. 
+ /// Model context size (n_ctx) + /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) + /// Seed for the random number generator (seed) + /// Whether to use f16 instead of f32 for memory kv (memory_f16) + /// Whether to use mmap for faster loads (use_mmap) + /// Whether to use mlock to keep model in memory (use_mlock) + /// Thether to compute perplexity over the prompt (perplexity) + /// Lora adapter path (lora_adapter) + /// Base model path for the lora adapter (lora_base) + /// Number of threads (-1 = autodetect) (n_threads) + /// Batch size for prompt processing (must be >=32 to use BLAS) (n_batch) + /// Whether to convert eos to newline during the inference. + /// Whether to use embedding mode. (embedding) Note that if this is set to true, The LLamaModel won't produce text response anymore. + /// Grouped-Query Attention + /// RMS Norm Epsilon + /// RoPE base frequency. + /// RoPE frequency scaling factor + /// Use experimental mul_mat_q kernels + /// The encoding to use to convert text for the model + public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20, + int seed = 1337, bool useFp16Memory = true, + bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false, + string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512, + bool convertEosToNewLine = false, bool embeddingMode = false, + int groupedQueryAttention = 1, float rmsNormEps = 5e-6f, float ropeFreqBase = 10000.0f, float ropeFreqScale = 1f, bool muMatQ = false, + string encoding = "UTF-8") { ContextSize = contextSize; GpuLayerCount = gpuLayerCount; @@ -155,10 +163,11 @@ namespace LLama.Common ConvertEosToNewLine = convertEosToNewLine; EmbeddingMode = embeddingMode; GroupedQueryAttention = groupedQueryAttention; - RmsNormEpsilon = rmsNormEpsilon; - RopeFrequencyBase = ropeFrequencyBase; - RopeFrequencyScale = ropeFrequencyScale; - MulMatQ = mulMatQ; - } + RmsNormEpsilon = rmsNormEps; + RopeFrequencyBase = ropeFreqBase; 
+ RopeFrequencyScale = ropeFreqScale; + MulMatQ = muMatQ; + Encoding = encoding; + } } } diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 75e30282..095f44af 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -10,7 +10,6 @@ using System.IO.MemoryMappedFiles; using LLama.Common; using System.Runtime.InteropServices; using LLama.Extensions; -using Microsoft.Win32.SafeHandles; using LLama.Abstractions; namespace LLama @@ -62,26 +61,25 @@ namespace LLama /// /// /// Model params. - /// Encoding to deal with text input. /// The logger. [Obsolete("Use the LLamaWeights.CreateContext instead")] - public LLamaContext(IModelParams @params, string encoding = "UTF-8", ILLamaLogger? logger = null) + public LLamaContext(IModelParams @params, ILLamaLogger? logger = null) { Params = @params; _logger = logger; - _encoding = Encoding.GetEncoding(encoding); + _encoding = Encoding.GetEncoding(@params.Encoding); _logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); _ctx = Utils.InitLLamaContextFromModelParams(Params); } - internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) + internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, ILLamaLogger? logger = null) { Params = @params; _logger = logger; - _encoding = encoding; + _encoding = Encoding.GetEncoding(@params.Encoding); _ctx = nativeContext; } @@ -90,10 +88,9 @@ namespace LLama /// /// /// - /// /// /// - public LLamaContext(LLamaWeights model, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) + public LLamaContext(LLamaWeights model, IModelParams @params, ILLamaLogger? 
logger = null) { if (model.NativeHandle.IsClosed) throw new ObjectDisposedException("Cannot create context, model weights have been disposed"); @@ -101,7 +98,7 @@ namespace LLama Params = @params; _logger = logger; - _encoding = encoding; + _encoding = Encoding.GetEncoding(@params.Encoding); using var pin = @params.ToLlamaContextParams(out var lparams); _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); @@ -116,7 +113,7 @@ namespace LLama using var pin = Params.ToLlamaContextParams(out var lparams); // Create a blank new context for the model - var ctx = new LLamaContext(SafeLLamaContextHandle.Create(NativeHandle.ModelHandle, lparams), Params, _encoding); + var ctx = new LLamaContext(SafeLLamaContextHandle.Create(NativeHandle.ModelHandle, lparams), Params); // Copy across the state using var state = GetState(); diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 50eeee0e..d86be5e7 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -18,30 +18,52 @@ namespace LLama public class StatelessExecutor : ILLamaExecutor { - private readonly LLamaContext _context; - private readonly LLamaContext.State _originalState; + private readonly LLamaWeights _weights; + private readonly IModelParams _params; /// /// The context used by the executor when running the inference. /// - public LLamaContext Context => _context; + public LLamaContext Context { get; private set; } /// - /// + /// Create a new stateless executor which will use the given model /// - /// The LLama model. 
+ /// + /// + public StatelessExecutor(LLamaWeights weights, IModelParams @params) + { + _weights = weights; + _params = @params; + + Context = _weights.CreateContext(_params); + Context.Dispose(); + } + + /// + /// Create a new stateless executor which will use the model used to create the given context + /// + /// + [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")] public StatelessExecutor(LLamaContext context) { - _context = context; - - var tokens = context.Tokenize(" ", true).ToArray(); - _context.NativeHandle.Eval(tokens.AsSpan(0, tokens.Length), 0, _context.Params.Threads); - _originalState = context.GetState(); + _weights = new LLamaWeights(context.NativeHandle.ModelHandle, Encoding.GetEncoding(context.Params.Encoding)); + _params = context.Params; + + Context = _weights.CreateContext(_params); + Context.Dispose(); } /// public IEnumerable Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { + using var context = _weights.CreateContext(_params); + Context = context; + + if (!Context.NativeHandle.IsClosed) + Context.Dispose(); + Context = _weights.CreateContext(Context.Params); + if (inferenceParams != null) { if (inferenceParams.TokensKeep > Context.ContextSize) @@ -58,10 +80,10 @@ namespace LLama for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++) lastTokens.Add(0); - var tokens = _context.Tokenize(text).ToList(); + var tokens = Context.Tokenize(text).ToList(); var n_prompt_tokens = tokens.Count; - _context.Eval(tokens, n_past); + Context.Eval(tokens, n_past); lastTokens.AddRange(tokens); n_past += n_prompt_tokens; @@ -71,21 +93,19 @@ namespace LLama for(var i = 0; i < max_tokens; i++) { if (cancellationToken.IsCancellationRequested) - { - _context.LoadState(_originalState); break; - } - var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? 
_context.ContextSize : inferenceParams.RepeatLastTokensCount; - var tokenDataArray = _context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, + var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount; + + var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - var id = _context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, + var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP); lastTokens.Add(id); - var response = _context.TokenToString(id); + var response = Context.TokenToString(id); yield return response; tokens.Clear(); @@ -96,7 +116,7 @@ namespace LLama // when run out of context // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433 - if (n_past + tokens.Count > _context.ContextSize) + if (n_past + tokens.Count > Context.ContextSize) { var n_left = n_past - inferenceParams.TokensKeep; @@ -106,10 +126,8 @@ namespace LLama tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2)); } - n_past = _context.Eval(tokens, n_past); + n_past = Context.Eval(tokens, n_past); } - - _context.LoadState(_originalState); } /// @@ -125,7 +143,7 @@ namespace LLama var builder = new StringBuilder(); foreach (var token in tokens) - builder.Append(_context.TokenToString(token)); + builder.Append(Context.TokenToString(token)); var last_output = builder.ToString(); diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index cb237a70..7d0ba1b0 100644 --- a/LLama/LLamaWeights.cs +++ 
b/LLama/LLamaWeights.cs @@ -20,9 +20,15 @@ namespace LLama /// Be careful how you use this! public SafeLlamaModelHandle NativeHandle => _weights; - private LLamaWeights(SafeLlamaModelHandle weights) + /// + /// Encoding to use to convert text into bytes for the model + /// + public Encoding Encoding { get; } + + internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding) { _weights = weights; + Encoding = encoding; } /// @@ -38,7 +44,7 @@ namespace LLama if (!string.IsNullOrEmpty(@params.LoraAdapter)) weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); - return new LLamaWeights(weights); + return new LLamaWeights(weights, Encoding.GetEncoding(@params.Encoding)); } /// @@ -51,11 +57,10 @@ namespace LLama /// Create a llama_context using this model /// /// - /// /// - public LLamaContext CreateContext(IModelParams @params, Encoding encoding) + public LLamaContext CreateContext(IModelParams @params) { - return new LLamaContext(this, @params, encoding); + return new LLamaContext(this, @params); } } } diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index bb5911fe..2e499196 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -138,7 +138,6 @@ namespace LLama.Native /// Rows: n_tokens
/// Cols: n_vocab ///
- /// /// public Span GetLogits() { From a45d9089e1743fdf32d6ecd70fac9622a09fe7aa Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 22 Aug 2023 01:32:53 +0100 Subject: [PATCH 10/10] Fixed demos --- LLama.Examples/NewVersion/ChatSessionStripRoleName.cs | 2 +- LLama.Examples/NewVersion/ChatSessionWithRoleName.cs | 2 +- LLama.Examples/NewVersion/InstructModeExecute.cs | 2 +- LLama.Examples/NewVersion/InteractiveModeExecute.cs | 2 +- LLama.Examples/NewVersion/LoadAndSaveSession.cs | 2 +- LLama.Examples/NewVersion/LoadAndSaveState.cs | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs index 230118e5..46c0505e 100644 --- a/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs +++ b/LLama.Examples/NewVersion/ChatSessionStripRoleName.cs @@ -13,7 +13,7 @@ namespace LLama.Examples.NewVersion var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters, Encoding.UTF8); + using var context = model.CreateContext(parameters); var executor = new InteractiveExecutor(context); var session = new ChatSession(executor).WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Bob:" }, redundancyLength: 8)); diff --git a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs index a3609388..5e155252 100644 --- a/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs +++ b/LLama.Examples/NewVersion/ChatSessionWithRoleName.cs @@ -13,7 +13,7 @@ namespace LLama.Examples.NewVersion var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters, Encoding.UTF8); + using var context = 
model.CreateContext(parameters); var executor = new InteractiveExecutor(context); var session = new ChatSession(executor); diff --git a/LLama.Examples/NewVersion/InstructModeExecute.cs b/LLama.Examples/NewVersion/InstructModeExecute.cs index 0a384062..d5b31416 100644 --- a/LLama.Examples/NewVersion/InstructModeExecute.cs +++ b/LLama.Examples/NewVersion/InstructModeExecute.cs @@ -13,7 +13,7 @@ namespace LLama.Examples.NewVersion var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters, Encoding.UTF8); + using var context = model.CreateContext(parameters); var executor = new InstructExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow; diff --git a/LLama.Examples/NewVersion/InteractiveModeExecute.cs b/LLama.Examples/NewVersion/InteractiveModeExecute.cs index 9fee007f..cc7c2891 100644 --- a/LLama.Examples/NewVersion/InteractiveModeExecute.cs +++ b/LLama.Examples/NewVersion/InteractiveModeExecute.cs @@ -13,7 +13,7 @@ namespace LLama.Examples.NewVersion var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters, Encoding.UTF8); + using var context = model.CreateContext(parameters); var ex = new InteractiveExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow; diff --git a/LLama.Examples/NewVersion/LoadAndSaveSession.cs b/LLama.Examples/NewVersion/LoadAndSaveSession.cs index 5e5c4252..9ca07de0 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveSession.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveSession.cs @@ -13,7 +13,7 @@ namespace LLama.Examples.NewVersion var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters, 
Encoding.UTF8); + using var context = model.CreateContext(parameters); var ex = new InteractiveExecutor(context); var session = new ChatSession(ex); diff --git a/LLama.Examples/NewVersion/LoadAndSaveState.cs b/LLama.Examples/NewVersion/LoadAndSaveState.cs index 1a1c0d88..d72fade8 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveState.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveState.cs @@ -13,7 +13,7 @@ namespace LLama.Examples.NewVersion var parameters = new ModelParams(modelPath, contextSize: 1024, seed: 1337, gpuLayerCount: 5); using var model = LLamaWeights.LoadFromFile(parameters); - using var context = model.CreateContext(parameters, Encoding.UTF8); + using var context = model.CreateContext(parameters); var ex = new InteractiveExecutor(context); Console.ForegroundColor = ConsoleColor.Yellow;