Browse Source

Add an AddPromptAsync method for stateful executors, plus ChatSession methods to initialize a session from history and to pre-process a system message. Executor state is now serialized to JSON when saving, to prevent saved states from being modified by reference.

tags/0.11.0
eublefar 1 year ago
parent
commit
b2f7dbb39b
3 changed files with 86 additions and 5 deletions
  1. +57
    -4
      LLama/ChatSession.cs
  2. +1
    -1
      LLama/LLamaContext.cs
  3. +28
    -0
      LLama/LLamaExecutorBase.cs

+ 57
- 4
LLama/ChatSession.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
@@ -47,6 +48,27 @@ public class ChatSession
/// </summary>
public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

/// <summary>
/// Create a new chat session and preprocess history.
/// </summary>
/// <param name="executor">The executor for this session; must be a <see cref="StatefulExecutorBase"/></param>
/// <param name="history">History for this session</param>
/// <param name="cancellationToken">Cancellation token to stop session pre-processing</param>
/// <returns>The new session, with the history already decoded into the executor's KV cache</returns>
/// <exception cref="ArgumentException">Thrown when <paramref name="executor"/> is not a <see cref="StatefulExecutorBase"/></exception>
public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
    ILLamaExecutor executor,
    ChatHistory history,
    CancellationToken cancellationToken = default)
{
    if (executor is not StatefulExecutorBase statefulExecutor)
    {
        // Fixed wording: the message previously read "must have a StatefulExecutorBase".
        throw new ArgumentException("Executor must be a StatefulExecutorBase", nameof(executor));
    }

    var session = new ChatSession(executor, history);

    // Run the rendered history through the executor so the KV cache matches the transcript
    // before any generation happens.
    await statefulExecutor.AddPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);

    return session;
}

/// <summary>
/// Create a new chat session.
/// </summary>
@@ -144,7 +166,7 @@ public class ChatSession
{
return new()
{
ExecutorState = ((StatefulExecutorBase)Executor).GetStateData(),
ExecutorState = JsonSerializer.Serialize(((StatefulExecutorBase)Executor).GetStateData()),
ContextState = Executor.Context.GetState(),
InputTransformPipeline = InputTransformPipeline,
OutputTransform = OutputTransform,
@@ -169,7 +191,11 @@ public class ChatSession
}
else
{
statefulExecutor.LoadState(state.ExecutorState);
statefulExecutor.LoadState(
JsonSerializer.Deserialize(
state.ExecutorState, statefulExecutor.GetStateData().GetType()
) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state))
);
}
}
else
@@ -260,6 +286,33 @@ public class ChatSession
return this;
}

/// <summary>
/// Compute KV cache for the system message and add it to the chat history.
/// </summary>
/// <param name="content">The system message text; it is passed through the input transform pipeline before processing</param>
/// <param name="cancellationToken">Cancellation token to stop processing the prompt (new optional parameter; defaults preserve the old call shape)</param>
/// <returns>This session, for fluent chaining</returns>
/// <exception cref="InvalidOperationException">Thrown when the executor is not a <see cref="StatefulExecutorBase"/></exception>
/// <exception cref="ArgumentException">Thrown when the history already contains messages</exception>
public async Task<ChatSession> ProcessSystemMessage(string content, CancellationToken cancellationToken = default)
{
    if (Executor is not StatefulExecutorBase statefulExecutor)
    {
        throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
    }
    if (History.Messages.Count > 0)
    {
        throw new ArgumentException("Cannot add a system message after another message", nameof(content));
    }

    // Apply the session's input transforms before the text reaches the model,
    // so the stored history matches what was actually decoded.
    foreach (var inputTransform in InputTransformPipeline)
    {
        content = inputTransform.Transform(content);
    }

    // Fix: the original call did not flow a cancellation token into AddPromptAsync.
    await statefulExecutor.AddPromptAsync(content, cancellationToken);

    History.AddMessage(AuthorRole.System, content);
    return this;
}

/// <summary>
/// Add a system message to the chat history.
/// </summary>
@@ -557,9 +610,9 @@ public class ChatSession
public record SessionState
{
/// <summary>
/// Saved executor state for the session.
/// Saved executor state for the session in JSON format.
/// </summary>
public ExecutorBaseState? ExecutorState { get; init; }
public string? ExecutorState { get; init; }

/// <summary>
/// Saved context state (KV cache) for the session.


+ 1
- 1
LLama/LLamaContext.cs View File

@@ -222,7 +222,7 @@ namespace LLama
/// <exception cref="RuntimeError"></exception>
// Resets the context to an empty state.
// NOTE(review): this rendered diff shows BOTH the removed implementation
// (LoadState(_emptyState)) and its replacement (the native KV-cache clear);
// in the committed file only the llama_kv_cache_clear call should remain —
// confirm against the repository before applying.
public void ResetState()
{
LoadState(_emptyState);
NativeApi.llama_kv_cache_clear(NativeHandle);
}

/// <summary>


+ 28
- 0
LLama/LLamaExecutorBase.cs View File

@@ -323,6 +323,34 @@ namespace LLama
}
}

/// <summary>
/// Asynchronously runs a prompt through the model to compute KV cache without generating any new tokens.
/// </summary>
/// <param name="prompt">Prompt to process</param>
/// <param name="cancellationToken">A cancellation token</param>
/// <returns>A task that completes once the prompt has been decoded into the KV cache</returns>
/// <exception cref="OperationCanceledException">Thrown when <paramref name="cancellationToken"/> is cancelled between processing steps</exception>
public virtual async Task AddPromptAsync(string prompt, CancellationToken cancellationToken = default)
{
    // MaxTokens = 0 means "process the prompt, generate nothing".
    var inferenceParams = new InferenceParams
    {
        MaxTokens = 0
    };
    var args = new InferStateArgs
    {
        Antiprompts = new List<string>(),
        RemainedTokens = 0,
        ReturnValue = false,
        WaitForInput = true,
        NeedToSaveSession = false
    };

    // Fix: the original accepted a CancellationToken but never observed it.
    // InferInternal/PreprocessInputs take no token here, so honor cancellation
    // at the boundaries between the steps.
    cancellationToken.ThrowIfCancellationRequested();
    await PreprocessInputs(prompt, args);

    cancellationToken.ThrowIfCancellationRequested();
    // First run adds the prompt to the _embeds
    await InferInternal(inferenceParams, args);

    cancellationToken.ThrowIfCancellationRequested();
    // Second run puts it through decode
    await InferInternal(inferenceParams, args);
}

/// <summary>
/// State arguments that are used in single inference
/// </summary>


Loading…
Cancel
Save