diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index eec9a9d3..2a74bb29 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -3,6 +3,7 @@ using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Runtime.CompilerServices;
+using System.Text.Json;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Abstractions;
@@ -47,6 +48,27 @@ public class ChatSession
     /// </summary>
     public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();
 
+    /// <summary>
+    /// Create a new chat session and preprocess history.
+    /// </summary>
+    /// <param name="executor">The executor for this session</param>
+    /// <param name="history">History for this session</param>
+    /// <param name="cancellationToken">Cancellation token to stop session pre-processing</param>
+    /// <returns>A new chat session with the history pre-processed</returns>
+    public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
+        ILLamaExecutor executor,
+        ChatHistory history,
+        CancellationToken cancellationToken = default)
+    {
+        if (executor is not StatefulExecutorBase statefulExecutor)
+        {
+            throw new ArgumentException("Executor must be a StatefulExecutorBase", nameof(executor));
+        }
+        var session = new ChatSession(executor, history);
+        await statefulExecutor.AddPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);
+        return session;
+    }
+
     /// <summary>
     /// Create a new chat session.
     /// </summary>
@@ -144,7 +166,7 @@ public class ChatSession
     {
         return new()
         {
-            ExecutorState = ((StatefulExecutorBase)Executor).GetStateData(),
+            ExecutorState = JsonSerializer.Serialize(((StatefulExecutorBase)Executor).GetStateData()),
             ContextState = Executor.Context.GetState(),
             InputTransformPipeline = InputTransformPipeline,
             OutputTransform = OutputTransform,
@@ -169,7 +191,11 @@ public class ChatSession
         }
         else
         {
-            statefulExecutor.LoadState(state.ExecutorState);
+            statefulExecutor.LoadState(
+                JsonSerializer.Deserialize(
+                    state.ExecutorState, statefulExecutor.GetStateData().GetType()
+                ) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state))
+            );
         }
     }
     else
@@ -260,6 +286,33 @@ public class ChatSession
 
         return this;
     }
+
+    /// <summary>
+    /// Compute KV cache for the system message and add it to the chat history.
+    /// </summary>
+    /// <param name="content">Content of the system message</param>
+    /// <returns>The current chat session</returns>
+    public async Task<ChatSession> ProcessSystemMessage(string content)
+    {
+        if (Executor is not StatefulExecutorBase statefulExecutor)
+        {
+            throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
+        }
+        if (History.Messages.Count > 0)
+        {
+            throw new ArgumentException("Cannot add a system message after another message", nameof(content));
+        }
+        foreach (var inputTransform in InputTransformPipeline)
+        {
+            content = inputTransform.Transform(content);
+        }
+
+        await statefulExecutor.AddPromptAsync(content);
+
+        History.AddMessage(AuthorRole.System, content);
+        return this;
+    }
+
     /// <summary>
     /// Add a system message to the chat history.
     /// </summary>
@@ -557,9 +610,9 @@ public class ChatSession
 public record SessionState
 {
     /// <summary>
-    /// Saved executor state for the session.
+    /// Saved executor state for the session in JSON format.
     /// </summary>
-    public ExecutorBaseState? ExecutorState { get; init; }
+    public string? ExecutorState { get; init; }
 
     /// <summary>
     /// Saved context state (KV cache) for the session.
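For context, a minimal sketch of how the new entry point and the JSON-based executor state round-trip might be exercised. The model path and `ModelParams` setup are placeholders, and the `GetSessionState()`/`LoadSession(SessionState)` member names are assumed from the surrounding file rather than shown in this diff:

```csharp
using LLama;
using LLama.Common;

// Illustrative setup; "model.gguf" is a placeholder path.
var parameters = new ModelParams("model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

// Warm the KV cache with the existing history before the first user turn.
var history = new ChatHistory();
history.AddMessage(AuthorRole.System, "You are a helpful assistant.");
var session = await ChatSession.InitializeSessionFromHistoryAsync(executor, history);

// ExecutorState is now a JSON string, so SessionState can be persisted and
// reloaded without knowing the concrete executor state type up front.
SessionState state = session.GetSessionState();
session.LoadSession(state);
```

Serializing `ExecutorState` to JSON at save time (rather than holding the live `ExecutorBaseState` object) is what lets `SessionState` deserialize against whatever concrete state type the current executor reports via `GetStateData().GetType()`.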
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index b52b6e54..dc0508e5 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -222,7 +222,7 @@ namespace LLama
         /// </summary>
         public void ResetState()
         {
-            LoadState(_emptyState);
+            NativeApi.llama_kv_cache_clear(NativeHandle);
         }
 
         /// <summary>
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index f96b11bd..b9a0b412 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -323,6 +323,34 @@ namespace LLama
             }
         }
 
+        /// <summary>
+        /// Asynchronously runs a prompt through the model to compute the KV cache without generating any new tokens.
+        /// </summary>
+        /// <param name="prompt">Prompt to process</param>
+        /// <param name="cancellationToken">A cancellation token</param>
+        /// <returns></returns>
+        public virtual async Task AddPromptAsync(string prompt, CancellationToken cancellationToken = default)
+        {
+            var inferenceParams = new InferenceParams
+            {
+                MaxTokens = 0
+            };
+            var args = new InferStateArgs
+            {
+                Antiprompts = new List<string>(),
+                RemainedTokens = 0,
+                ReturnValue = false,
+                WaitForInput = true,
+                NeedToSaveSession = false
+            };
+
+            await PreprocessInputs(prompt, args);
+            // First run adds the prompt to _embeds
+            await InferInternal(inferenceParams, args);
+            // Second run puts it through decode
+            await InferInternal(inferenceParams, args);
+        }
+
         /// <summary>
         /// State arguments that are used in single inference
         /// </summary>
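A short sketch of how `AddPromptAsync` is meant to be driven, both via the new `ChatSession.ProcessSystemMessage` helper and directly. Per the comments in the diff, `MaxTokens = 0` plus the double `InferInternal` call tokenizes the prompt into `_embeds` on the first pass and decodes it into the KV cache on the second, with no sampling. The single-argument `ChatSession` constructor and the `context` variable are assumed from the earlier sketch:

```csharp
// Illustrative: prefill the KV cache with a system prompt before any generation.
// InteractiveExecutor derives from StatefulExecutorBase, so AddPromptAsync is available.
var executor = new InteractiveExecutor(context);

// Higher-level path: ProcessSystemMessage runs the input transforms, calls
// AddPromptAsync, and records the message in the (still empty) history.
var session = new ChatSession(executor);
await session.ProcessSystemMessage("You are a terse assistant.");

// Lower-level alternative on a fresh executor: prefill without touching any history.
// await executor.AddPromptAsync("You are a terse assistant.");
```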