diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index eec9a9d3..2a74bb29 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
+using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
@@ -47,6 +48,27 @@ public class ChatSession
/// </summary>
public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();
+ /// <summary>
+ /// Create a new chat session and preprocess history.
+ /// </summary>
+ /// <param name="executor">The executor for this session</param>
+ /// <param name="history">History for this session</param>
+ /// <param name="cancellationToken">Cancellation token to stop session pre-processing</param>
+ /// <returns></returns>
+ public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
+ ILLamaExecutor executor,
+ ChatHistory history,
+ CancellationToken cancellationToken = default)
+ {
+ if (executor is not StatefulExecutorBase statefulExecutor)
+ {
+ throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
+ }
+ var session = new ChatSession(executor, history);
+ await statefulExecutor.AddPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);
+ return session;
+ }
+
/// <summary>
/// Create a new chat session.
/// </summary>
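For orientation, a brief usage sketch of the new factory method; the `executor` here is assumed to be an already-constructed stateful executor such as `InteractiveExecutor`, which is not part of this diff:

```csharp
// Build a history, then let the factory pre-decode it into the KV cache
// so the first user turn does not pay for re-processing the transcript.
var history = new ChatHistory();
history.AddMessage(AuthorRole.System, "You are a helpful assistant.");
history.AddMessage(AuthorRole.User, "Hello!");

// Throws ArgumentException for executors that are not StatefulExecutorBase.
var session = await ChatSession.InitializeSessionFromHistoryAsync(executor, history);
```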
@@ -144,7 +166,7 @@ public class ChatSession
{
return new()
{
- ExecutorState = ((StatefulExecutorBase)Executor).GetStateData(),
+ ExecutorState = JsonSerializer.Serialize<object>(((StatefulExecutorBase)Executor).GetStateData()),
ContextState = Executor.Context.GetState(),
InputTransformPipeline = InputTransformPipeline,
OutputTransform = OutputTransform,
@@ -169,7 +191,11 @@ public class ChatSession
}
else
{
- statefulExecutor.LoadState(state.ExecutorState);
+ statefulExecutor.LoadState(
+ JsonSerializer.Deserialize(
+ state.ExecutorState, statefulExecutor.GetStateData().GetType()
+ ) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state))
+ );
}
}
else
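One detail worth calling out in the load path above: `System.Text.Json` does not resolve derived types through a base-class-typed call by default, which is presumably why the concrete runtime type from `GetStateData().GetType()` is passed to `Deserialize` (and why the serialize side should be pinned to the runtime type as well). A minimal standalone sketch of the pattern, with illustrative variable names:

```csharp
// Serialize with the runtime type so fields declared on the derived state
// class are included; a call bound to the base type would serialize only
// ExecutorBaseState's own properties.
ExecutorBaseState state = statefulExecutor.GetStateData();
string json = JsonSerializer.Serialize(state, state.GetType());

// Deserialize back into the same concrete type, then hand it to LoadState.
var restored = (ExecutorBaseState?)JsonSerializer.Deserialize(json, state.GetType())
    ?? throw new InvalidOperationException("Executor state JSON is invalid");
statefulExecutor.LoadState(restored);
```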
@@ -260,6 +286,33 @@ public class ChatSession
return this;
}
+
+ /// <summary>
+ /// Compute KV cache for the system message and add it to the chat history.
+ /// </summary>
+ /// <param name="content"></param>
+ /// <returns></returns>
+ public async Task<ChatSession> ProcessSystemMessage(string content)
+ {
+ if (Executor is not StatefulExecutorBase statefulExecutor)
+ {
+ throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
+ }
+ if (History.Messages.Count > 0)
+ {
+ throw new ArgumentException("Cannot add a system message after another message", nameof(content));
+ }
+ foreach (var inputTransform in InputTransformPipeline)
+ {
+ content = inputTransform.Transform(content);
+ }
+
+ await statefulExecutor.AddPromptAsync(content);
+
+ History.AddMessage(AuthorRole.System, content);
+ return this;
+ }
+
/// <summary>
/// Add a system message to the chat history.
/// </summary>
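A short sketch of how the new method is intended to be called; the session construction is assumed, not shown in this diff:

```csharp
var session = new ChatSession(executor);

// Must be the first message in the session: the content is run through the
// input transform pipeline, decoded into the KV cache, and then recorded
// in the history. Throws if any message has already been added.
await session.ProcessSystemMessage("You are a concise assistant.");
```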
@@ -557,9 +610,9 @@ public class ChatSession
public record SessionState
{
/// <summary>
- /// Saved executor state for the session.
+ /// Saved executor state for the session in JSON format.
/// </summary>
- public ExecutorBaseState? ExecutorState { get; init; }
+ public string? ExecutorState { get; init; }
/// <summary>
/// Saved context state (KV cache) for the session.
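Since `ExecutorState` is now a plain JSON string, a full `SessionState` round trip stays simple. A sketch under the assumption that the getter and loader this diff edits are `ChatSession.GetSessionState()` and `ChatSession.LoadSession(SessionState)`; persisting the snapshot to disk is left out:

```csharp
// Snapshot everything the session needs: executor state (JSON), context
// state (KV cache), transforms, and chat history.
SessionState snapshot = session.GetSessionState();

// ... later, restore onto a session backed by the same model/context ...
session.LoadSession(snapshot);
```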
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index b52b6e54..dc0508e5 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -222,7 +222,7 @@ namespace LLama
/// </summary>
public void ResetState()
{
- LoadState(_emptyState);
+ NativeApi.llama_kv_cache_clear(NativeHandle);
}
/// <summary>
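The `ResetState` change replaces restoring a saved empty-state snapshot with a single native KV-cache clear, avoiding the cost of re-loading serialized state just to empty the context. A sketch of the observable behaviour, with the `model`/`parameters` setup and generation calls assumed:

```csharp
using var context = model.CreateContext(parameters);

// ... decode a first prompt through the context ...

// Start over without recreating the context: the native call drops every
// cached token, equivalent in effect to loading the original empty state.
context.ResetState();

// ... the next decode now begins from an empty KV cache ...
```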
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index f96b11bd..b9a0b412 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -323,6 +323,34 @@ namespace LLama
}
}
+ /// <summary>
+ /// Asynchronously runs a prompt through the model to compute the KV cache, without generating any new tokens.
+ /// </summary>
+ /// <param name="prompt">Prompt to process</param>
+ /// <param name="cancellationToken">A cancellation token</param>
+ /// <returns></returns>
+ public virtual async Task AddPromptAsync(string prompt, CancellationToken cancellationToken = default)
+ {
+ var inferenceParams = new InferenceParams
+ {
+ MaxTokens = 0
+ };
+ var args = new InferStateArgs
+ {
+ Antiprompts = new List<string>(),
+ RemainedTokens = 0,
+ ReturnValue = false,
+ WaitForInput = true,
+ NeedToSaveSession = false
+ };
+
+ await PreprocessInputs(prompt, args);
+ // The first run adds the prompt tokens to _embeds
+ await InferInternal(inferenceParams, args);
+ // The second run feeds those tokens through the decode step
+ await InferInternal(inferenceParams, args);
+ }
+
/// <summary>
/// State arguments that are used in single inference
/// </summary>
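Finally, a usage sketch for `AddPromptAsync`: pre-decoding a prompt so that later inference only pays for new tokens. The `InteractiveExecutor` construction is assumed:

```csharp
var executor = new InteractiveExecutor(context);

// Runs the prompt through the model with MaxTokens = 0: the KV cache is
// filled, but no tokens are sampled or returned.
await executor.AddPromptAsync("You are a helpful assistant.");

// A later inference call starts from the warmed cache.
await foreach (var text in executor.InferAsync("Hello!", new InferenceParams()))
    Console.Write(text);
```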