@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
@@ -47,6 +48,27 @@ public class ChatSession
/// </summary>
public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();
/// <summary>
/// Create a new chat session and preprocess history.
/// </summary>
/// <param name="executor">The executor for this session</param>
/// <param name="history">History for this session</param>
/// <param name="cancellationToken">Cancellation token to stop session pre-processing</param>
/// <returns></returns>
public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
ILLamaExecutor executor,
ChatHistory history,
CancellationToken cancellationToken = default)
{
if (executor is not StatefulExecutorBase statefulExecutor)
{
throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
}
var session = new ChatSession(executor, history);
await statefulExecutor.AddPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);
return session;
}
/// <summary>
/// Create a new chat session.
/// </summary>
@@ -144,7 +166,7 @@ public class ChatSession
{
return new()
{
ExecutorState = ((StatefulExecutorBase)Executor).GetStateData(),
ExecutorState = JsonSerializer.Serialize( ((StatefulExecutorBase)Executor).GetStateData() ),
ContextState = Executor.Context.GetState(),
InputTransformPipeline = InputTransformPipeline,
OutputTransform = OutputTransform,
@@ -169,7 +191,11 @@ public class ChatSession
}
else
{
statefulExecutor.LoadState(state.ExecutorState);
statefulExecutor.LoadState(
JsonSerializer.Deserialize(
state.ExecutorState, statefulExecutor.GetStateData().GetType()
) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state))
);
}
}
else
@@ -260,6 +286,33 @@ public class ChatSession
return this;
}
/// <summary>
/// Compute KV cache for the system message and add it to the chat history.
/// </summary>
/// <param name="content"></param>
/// <returns></returns>
public async Task<ChatSession> ProcessSystemMessage(string content)
{
if (Executor is not StatefulExecutorBase statefulExecutor)
{
throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
}
if (History.Messages.Count > 0)
{
throw new ArgumentException("Cannot add a system message after another message", nameof(content));
}
foreach (var inputTransform in InputTransformPipeline)
{
content = inputTransform.Transform(content);
}
await statefulExecutor.AddPromptAsync(content);
History.AddMessage(AuthorRole.System, content);
return this;
}
/// <summary>
/// Add a system message to the chat history.
/// </summary>
@@ -557,9 +610,9 @@ public class ChatSession
public record SessionState
{
/// <summary>
/// Saved executor state for the session.
/// Saved executor state for the session in JSON format .
/// </summary>
public ExecutorBaseState ? ExecutorState { get; init; }
public string ? ExecutorState { get; init; }
/// <summary>
/// Saved context state (KV cache) for the session.