Browse Source

Merge pull request #560 from eublefar/feature/chat-session-state-management

Chat session state management
tags/0.11.0
Rinne GitHub 2 years ago
parent
commit
b677cdc6a3
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
14 changed files with 680 additions and 57 deletions
  1. +1
    -0
      LLama.Examples/ExampleRunner.cs
  2. +12
    -0
      LLama.Examples/Examples/ChatSessionWithHistory.cs
  3. +107
    -0
      LLama.Examples/Examples/ChatSessionWithRestart.cs
  4. +8
    -0
      LLama/Abstractions/IHistoryTransform.cs
  5. +10
    -1
      LLama/Abstractions/ITextStreamTransform.cs
  6. +11
    -1
      LLama/Abstractions/ITextTransform.cs
  7. +325
    -38
      LLama/ChatSession.cs
  8. +10
    -0
      LLama/Common/ChatHistory.cs
  9. +57
    -0
      LLama/Common/PolymorphicJSONConverter.cs
  10. +28
    -2
      LLama/LLamaContext.cs
  11. +32
    -3
      LLama/LLamaExecutorBase.cs
  12. +6
    -6
      LLama/LLamaInstructExecutor.cs
  13. +6
    -6
      LLama/LLamaInteractExecutor.cs
  14. +67
    -0
      LLama/LLamaTransforms.cs

+ 1
- 0
LLama.Examples/ExampleRunner.cs View File

@@ -8,6 +8,7 @@ public class ExampleRunner
{ "Chat Session: History", ChatSessionWithHistory.Run },
{ "Chat Session: Role names", ChatSessionWithRoleName.Run },
{ "Chat Session: Role names stripped", ChatSessionStripRoleName.Run },
{ "Chat Session: Pre-processing and reset", ChatSessionWithRestart.Run },
{ "Chat Session: Coding Assistant", CodingAssistant.Run },
{ "Chat Session: Automatic conversation", TalkToYourself.Run },
{ "Chat Session: Chinese characters", ChatChineseGB2312.Run },


+ 12
- 0
LLama.Examples/Examples/ChatSessionWithHistory.cs View File

@@ -48,6 +48,10 @@ public class ChatSessionWithHistory

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started.");
Console.WriteLine("Type 'exit' to end the chat session.");
Console.WriteLine("Type 'save' to save the chat session to disk.");
Console.WriteLine("Type 'load' to load the chat session from disk.");
Console.WriteLine("Type 'regenerate' to regenerate the last response.");

// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
@@ -55,12 +59,20 @@ public class ChatSessionWithHistory

while (userInput != "exit")
{
// Save the chat state to disk
if (userInput == "save")
{
session.SaveSession("Assets/chat-with-bob");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session saved.");
}
// Load the chat state from disk
else if (userInput == "load")
{
session.LoadSession("Assets/chat-with-bob");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session loaded.");
}
else if (userInput == "regenerate")
{
Console.ForegroundColor = ConsoleColor.Yellow;


+ 107
- 0
LLama.Examples/Examples/ChatSessionWithRestart.cs View File

@@ -0,0 +1,107 @@
using LLama.Common;

namespace LLama.Examples.Examples;

public class ChatSessionWithRestart
{
public static async Task Run()
{
string modelPath = UserSettings.GetModelPath();

var parameters = new ModelParams(modelPath)
{
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
ChatSession prototypeSession =
await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory);
prototypeSession.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
new string[] { "User:", "Assistant:" },
redundancyLength: 8));
var resetState = prototypeSession.GetSessionState();

ChatSession session = new ChatSession(executor);
session.LoadSession(resetState);

InferenceParams inferenceParams = new InferenceParams()
{
Temperature = 0.9f,
AntiPrompts = new List<string> { "User:" }
};

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started. Starting point saved.");
Console.WriteLine("Type 'exit' to end the chat session.");
Console.WriteLine("Type 'save' to save chat session state in memory.");
Console.WriteLine("Type 'reset' to reset the chat session to its saved state.");
Console.WriteLine("Type 'answer for assistant' to add and process provided user and assistant messages.");

// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
string userInput = Console.ReadLine() ?? "";

while (userInput != "exit")
{
// Load the session state from the reset state
if(userInput == "reset")
{
session.LoadSession(resetState);
Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.History)}");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session reset.");
}
// Assign new reset state.
else if (userInput == "save")
{
resetState = session.GetSessionState();
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Session saved.");
}
// Provide user and override assistant answer with your own.
else if (userInput == "answer for assistant")
{
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Provide user input: ");

Console.ForegroundColor = ConsoleColor.Green;
string userInputOverride = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Provide assistant input: ");
Console.ForegroundColor = ConsoleColor.Green;
string assistantInputOverride = Console.ReadLine() ?? "";
await session.AddAndProcessUserMessage(userInputOverride);
await session.AddAndProcessAssistantMessage(assistantInputOverride);

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("User and assistant messages processed. Provide next user message:");
}
else
{
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write(text);
}
}

Console.ForegroundColor = ConsoleColor.Green;
userInput = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.White;
}
}
}

+ 8
- 0
LLama/Abstractions/IHistoryTransform.cs View File

@@ -1,10 +1,12 @@
using LLama.Common;
using System.Text.Json.Serialization;

namespace LLama.Abstractions
{
/// <summary>
/// Transform history to plain text and vice versa.
/// </summary>
[JsonConverter(typeof(PolymorphicJSONConverter<IHistoryTransform>))]
public interface IHistoryTransform
{
/// <summary>
@@ -21,5 +23,11 @@ namespace LLama.Abstractions
/// <param name="text">The chat history as plain text.</param>
/// <returns>The updated history.</returns>
ChatHistory TextToHistory(AuthorRole role, string text);

/// <summary>
/// Copy the transform.
/// </summary>
/// <returns></returns>
IHistoryTransform Clone();
}
}

+ 10
- 1
LLama/Abstractions/ITextStreamTransform.cs View File

@@ -1,10 +1,13 @@
using System.Collections.Generic;
using LLama.Common;
using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace LLama.Abstractions
{
/// <summary>
/// Takes a stream of tokens and transforms them.
/// </summary>
[JsonConverter(typeof(PolymorphicJSONConverter<ITextStreamTransform>))]
public interface ITextStreamTransform
{
/// <summary>
@@ -13,5 +16,11 @@ namespace LLama.Abstractions
/// <param name="tokens"></param>
/// <returns></returns>
IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens);

/// <summary>
/// Copy the transform.
/// </summary>
/// <returns></returns>
ITextStreamTransform Clone();
}
}

+ 11
- 1
LLama/Abstractions/ITextTransform.cs View File

@@ -1,4 +1,7 @@
namespace LLama.Abstractions
using System.Text.Json.Serialization;
using LLama.Common;

namespace LLama.Abstractions
{
/// <summary>
/// An interface for text transformations.
@@ -9,6 +12,7 @@
/// - Trimming
/// - etc.
/// </summary>
[JsonConverter(typeof(PolymorphicJSONConverter<ITextTransform>))]
public interface ITextTransform
{
/// <summary>
@@ -17,5 +21,11 @@
/// <param name="text"></param>
/// <returns></returns>
string Transform(string text);

/// <summary>
/// Copy the transform.
/// </summary>
/// <returns></returns>
ITextTransform Clone();
}
}

+ 325
- 38
LLama/ChatSession.cs View File

@@ -3,11 +3,14 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;
using static LLama.InteractiveExecutor;
using static LLama.LLamaContext;
using static LLama.StatefulExecutorBase;

namespace LLama;

@@ -16,9 +19,30 @@ namespace LLama;
/// </summary>
public class ChatSession
{
private const string _modelStateFilename = "ModelState.st";
private const string _executorStateFilename = "ExecutorState.json";
private const string _hsitoryFilename = "ChatHistory.json";
/// <summary>
/// The filename for the serialized model state (KV cache, etc).
/// </summary>
public const string MODEL_STATE_FILENAME = "ModelState.st";
/// <summary>
/// The filename for the serialized executor state.
/// </summary>
public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json";
/// <summary>
/// The filename for the serialized chat history.
/// </summary>
public const string HISTORY_STATE_FILENAME = "ChatHistory.json";
/// <summary>
/// The filename for the serialized input transform pipeline.
/// </summary>
public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json";
/// <summary>
/// The filename for the serialized output transform.
/// </summary>
public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json";
/// <summary>
/// The filename for the serialized history transform.
/// </summary>
public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json";

/// <summary>
/// The executor for this session.
@@ -45,6 +69,24 @@ public class ChatSession
/// </summary>
public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform();

/// <summary>
/// Create a new chat session and preprocess history.
/// </summary>
/// <param name="executor">The executor for this session</param>
/// <param name="history">History for this session</param>
/// <returns></returns>
public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
ILLamaExecutor executor, ChatHistory history)
{
if (executor is not StatefulExecutorBase statefulExecutor)
{
throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));
}
var session = new ChatSession(executor, history);
await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
return session;
}

/// <summary>
/// Create a new chat session.
/// </summary>
@@ -112,56 +154,76 @@ public class ChatSession
/// <exception cref="ArgumentException"></exception>
public void SaveSession(string path)
{
if (string.IsNullOrWhiteSpace(path))
GetSessionState().Save(path);
}

/// <summary>
/// Get the session state.
/// </summary>
/// <returns>SessionState object representing session state in-memory</returns>
public SessionState GetSessionState()
{
var executorState = ((StatefulExecutorBase)Executor).GetStateData();
return new SessionState(
executorState.PastTokensCount > 0
? Executor.Context.GetState() : null,
executorState,
History,
InputTransformPipeline,
OutputTransform,
HistoryTransform);
}

/// <summary>
/// Load a session from a session state.
/// </summary>
/// <param name="state"></param>
/// <param name="loadTransforms">If true loads transforms saved in the session state.</param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public void LoadSession(SessionState state, bool loadTransforms = true)
{
if (Executor is StatefulExecutorBase statefulExecutor)
{
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
if (state.ExecutorState is not null)
{
statefulExecutor.LoadState(state.ExecutorState);
}
}

if (Directory.Exists(path))
if (state.ContextState is null)
{
Directory.Delete(path, recursive: true);
Executor.Context.NativeHandle.KvCacheClear();
}
else
{
Executor.Context.LoadState(state.ContextState);
}
History = new ChatHistory(state.History);
if (loadTransforms)
{
InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList();
OutputTransform = state.OutputTransform.Clone();
HistoryTransform = state.HistoryTransform.Clone();
}

Directory.CreateDirectory(path);

string modelStateFilePath = Path.Combine(path, _modelStateFilename);
Executor.Context.SaveState(modelStateFilePath);

string executorStateFilepath = Path.Combine(path, _executorStateFilename);
((StatefulExecutorBase)Executor).SaveState(executorStateFilepath);

string historyFilepath = Path.Combine(path, _hsitoryFilename);
File.WriteAllText(historyFilepath, History.ToJson());
}

/// <summary>
/// Load a session from a directory.
/// </summary>
/// <param name="path"></param>
/// <param name="loadTransforms">If true loads transforms saved in the session state.</param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public void LoadSession(string path)
public void LoadSession(string path, bool loadTransforms = true)
{
if (string.IsNullOrWhiteSpace(path))
var state = SessionState.Load(path);
// Handle non-polymorphic serialization of executor state
if (state.ExecutorState is null)
{
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
}

if (!Directory.Exists(path))
{
throw new ArgumentException("Directory does not exist", nameof(path));
var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME);
((StatefulExecutorBase) Executor).LoadState(filename: executorPath);
}

string modelStateFilePath = Path.Combine(path, _modelStateFilename);
Executor.Context.LoadState(modelStateFilePath);

string executorStateFilepath = Path.Combine(path, _executorStateFilename);
((StatefulExecutorBase)Executor).LoadState(executorStateFilepath);

string historyFilepath = Path.Combine(path, _hsitoryFilename);
string historyJson = File.ReadAllText(historyFilepath);
History = ChatHistory.FromJson(historyJson)
?? throw new ArgumentException("History file is invalid", nameof(path));
LoadSession(state, loadTransforms);
}

/// <summary>
@@ -238,6 +300,49 @@ public class ChatSession
return this;
}

/// <summary>
/// Compute KV cache for the message and add it to the chat history.
/// </summary>
/// <param name="message"></param>
/// <returns></returns>
public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
{
if (Executor is not StatefulExecutorBase statefulExecutor)
{
throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
}
AddMessage(message);
var content = message.Content;
if (message.AuthorRole != AuthorRole.Assistant)
{
foreach (var inputTransform in InputTransformPipeline)
{
content = inputTransform.Transform(content);
}
}

await statefulExecutor.PrefillPromptAsync(content);
return this;
}

/// <summary>
/// Compute KV cache for the system message and add it to the chat history.
/// </summary>
public Task<ChatSession> AddAndProcessSystemMessage(string content)
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));

/// <summary>
/// Compute KV cache for the user message and add it to the chat history.
/// </summary>
public Task<ChatSession> AddAndProcessUserMessage(string content)
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));

/// <summary>
/// Compute KV cache for the assistant message and add it to the chat history.
/// </summary>
public Task<ChatSession> AddAndProcessAssistantMessage(string content)
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));

/// <summary>
/// Replace a user message with a new message and remove all messages after the new message.
/// This is useful when the user wants to edit a message. And regenerate the response.
@@ -494,3 +599,185 @@ public class ChatSession
}
}
}

/// <summary>
/// The state of a chat session in-memory.
/// </summary>
public record SessionState
{
/// <summary>
/// Saved executor state for the session in JSON format.
/// </summary>
public ExecutorBaseState? ExecutorState { get; set; }

/// <summary>
/// Saved context state (KV cache) for the session.
/// </summary>
public State? ContextState { get; set; }

/// <summary>
/// The input transform pipeline used in this session.
/// </summary>
public ITextTransform[] InputTransformPipeline { get; set; } = Array.Empty<ITextTransform>();

/// <summary>
/// The output transform used in this session.
/// </summary>
public ITextStreamTransform OutputTransform { get; set; } = new LLamaTransforms.EmptyTextOutputStreamTransform();

/// <summary>
/// The history transform used in this session.
/// </summary>
public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
/// <summary>
/// The the chat history messages for this session.
/// </summary>
public ChatHistory.Message[] History { get; set; } = Array.Empty<ChatHistory.Message>();

/// <summary>
/// Create a new session state.
/// </summary>
/// <param name="contextState"></param>
/// <param name="executorState"></param>
/// <param name="history"></param>
/// <param name="inputTransformPipeline"></param>
/// <param name="outputTransform"></param>
/// <param name="historyTransform"></param>
public SessionState(
State? contextState, ExecutorBaseState executorState,
ChatHistory history, List<ITextTransform> inputTransformPipeline,
ITextStreamTransform outputTransform, IHistoryTransform historyTransform)
{
ContextState = contextState;
ExecutorState = executorState;
History = history.Messages.ToArray();
InputTransformPipeline = inputTransformPipeline.Select(t => t.Clone()).ToArray();
OutputTransform = outputTransform.Clone();
HistoryTransform = historyTransform.Clone();
}

/// <summary>
/// Save the session state to folder.
/// </summary>
/// <param name="path"></param>
public void Save(string path)
{
if (string.IsNullOrWhiteSpace(path))
{
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
}

if (Directory.Exists(path))
{
Directory.Delete(path, recursive: true);
}

Directory.CreateDirectory(path);

string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
var bytes = ContextState?.ToByteArray();
if (bytes is not null)
{
File.WriteAllBytes(modelStateFilePath, bytes);
}

string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState));

string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson());

string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME);
File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline));

string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);
File.WriteAllText(outputTransformFilepath, JsonSerializer.Serialize(OutputTransform));

string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME);
File.WriteAllText(historyTransformFilepath, JsonSerializer.Serialize(HistoryTransform));
}

/// <summary>
/// Load the session state from folder.
/// </summary>
/// <param name="path"></param>
/// <returns></returns>
/// <exception cref="ArgumentException">Throws when session state is incorrect</exception>
public static SessionState Load(string path)
{
if (string.IsNullOrWhiteSpace(path))
{
throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
}

if (!Directory.Exists(path))
{
throw new ArgumentException("Directory does not exist", nameof(path));
}

string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
var contextState = File.Exists(modelStateFilePath) ?
State.FromByteArray(File.ReadAllBytes(modelStateFilePath))
: null;

string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
var executorState = JsonSerializer.Deserialize<ExecutorBaseState>(File.ReadAllText(executorStateFilepath));

string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
string historyJson = File.ReadAllText(historyFilepath);
var history = ChatHistory.FromJson(historyJson)
?? throw new ArgumentException("History file is invalid", nameof(path));

string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME);
ITextTransform[] inputTransforms;
try
{
inputTransforms = File.Exists(inputTransformFilepath) ?
(JsonSerializer.Deserialize<ITextTransform[]>(File.ReadAllText(inputTransformFilepath))
?? throw new ArgumentException("Input transform file is invalid", nameof(path)))
: Array.Empty<ITextTransform>();
}
catch (JsonException)
{
throw new ArgumentException("Input transform file is invalid", nameof(path));
}

string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);
ITextStreamTransform outputTransform;
try
{
outputTransform = File.Exists(outputTransformFilepath) ?
(JsonSerializer.Deserialize<ITextStreamTransform>(File.ReadAllText(outputTransformFilepath))
?? throw new ArgumentException("Output transform file is invalid", nameof(path)))
: new LLamaTransforms.EmptyTextOutputStreamTransform();
}
catch (JsonException)
{
throw new ArgumentException("Output transform file is invalid", nameof(path));
}

string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME);
IHistoryTransform historyTransform;
try
{
historyTransform = File.Exists(historyTransformFilepath) ?
(JsonSerializer.Deserialize<IHistoryTransform>(File.ReadAllText(historyTransformFilepath))
?? throw new ArgumentException("History transform file is invalid", nameof(path)))
: new LLamaTransforms.DefaultHistoryTransform();
}
catch (JsonException)
{
throw new ArgumentException("History transform file is invalid", nameof(path));
}

return new SessionState(
contextState,
executorState,
history,
inputTransforms.ToList(),
outputTransform,
historyTransform);
}
}

+ 10
- 0
LLama/Common/ChatHistory.cs View File

@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;

@@ -80,6 +81,15 @@ namespace LLama.Common
[JsonConstructor]
public ChatHistory() { }

/// <summary>
/// Create a new instance of the chat history from array of messages
/// </summary>
/// <param name="messageHistory"></param>
public ChatHistory(Message[] messageHistory)
{
this.Messages = messageHistory.ToList();
}

/// <summary>
/// Add a message to the chat history
/// </summary>


+ 57
- 0
LLama/Common/PolymorphicJSONConverter.cs View File

@@ -0,0 +1,57 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace LLama.Common
{
internal class PolymorphicJSONConverter<T> : JsonConverter<T>
{
public override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
if (reader.TokenType != JsonTokenType.StartObject)
throw new JsonException();
reader.Read();
if (reader.TokenType != JsonTokenType.PropertyName)
throw new JsonException();
string? propertyName = reader.GetString();
if (propertyName != "Name")
return default;
reader.Read();
if (reader.TokenType != JsonTokenType.String)
throw new JsonException();
string? name = reader.GetString() ?? throw new JsonException();
var inheritedTypes = Assembly.GetExecutingAssembly().GetTypes().Where(
t => typeof(T).IsAssignableFrom(t) && !t.IsAbstract && !t.IsInterface
);
var type = inheritedTypes.FirstOrDefault(t => t.Name == name);
if (type == null)
throw new JsonException();
reader.Read();
if (reader.TokenType != JsonTokenType.PropertyName)
throw new JsonException();
propertyName = reader.GetString();
if (propertyName != "Data")
throw new JsonException();
var data = JsonSerializer.Deserialize(ref reader, type, options);
if (data == null)
throw new JsonException();
reader.Read();
reader.Read();
return (T)data;
}

public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
{
writer.WriteStartObject();
writer.WriteString("Name", value.GetType().Name);
writer.WritePropertyName("Data");
JsonSerializer.Serialize(writer, value, value.GetType(), options);
writer.WriteEndObject();
}
}
}

+ 28
- 2
LLama/LLamaContext.cs View File

@@ -204,7 +204,7 @@ namespace LLama
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);

// Wrap memory in a "state"
var state = new State(memory);
var state = new State(memory, actualSize);

// Set memory to zero, to prevent it being freed in finally block
memory = IntPtr.Zero;
@@ -422,9 +422,12 @@ namespace LLama
public class State
: SafeLLamaHandleBase
{
internal State(IntPtr memory)
private ulong _size;

internal State(IntPtr memory, ulong size)
: base(memory, true)
{
_size = size;
}

/// <inheritdoc />
@@ -433,6 +436,29 @@ namespace LLama
Marshal.FreeHGlobal(handle);
return true;
}

/// <summary>
/// Convert this state to a byte array
/// </summary>
/// <returns></returns>
public byte[] ToByteArray()
{
var bytes = new byte[_size];
Marshal.Copy(handle, bytes, 0, (int)_size);
return bytes;
}

/// <summary>
/// Load state from a byte array
/// </summary>
/// <param name="bytes"></param>
/// <returns></returns>
public static State FromByteArray(byte[] bytes)
{
var memory = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, memory, bytes.Length);
return new State(memory, (ulong)bytes.Length);
}
}
}
}

+ 32
- 3
LLama/LLamaExecutorBase.cs View File

@@ -315,6 +315,34 @@ namespace LLama
}
}

/// <summary>
/// Asynchronously runs a prompt through the model to compute KV cache without generating any new tokens.
/// It could reduce the latency of the first time response if the first input from the user is not immediate.
/// </summary>
/// <param name="prompt">Prompt to process</param>
/// <returns></returns>
public virtual async Task PrefillPromptAsync(string prompt)
{
var inferenceParams = new InferenceParams
{
MaxTokens = 0
};
var args = new InferStateArgs
{
Antiprompts = new List<string>(),
RemainedTokens = 0,
ReturnValue = false,
WaitForInput = true,
NeedToSaveSession = false
};

await PreprocessInputs(prompt, args);
// First run adds the prompt to the _embeds
await InferInternal(inferenceParams, args);
// Second run puts it through decode
await InferInternal(inferenceParams, args);
}

/// <summary>
/// State arguments that are used in single inference
/// </summary>
@@ -342,6 +370,7 @@ namespace LLama
public bool NeedToSaveSession { get; set; }
}

[JsonConverter(typeof(PolymorphicJSONConverter<ExecutorBaseState>))]
public class ExecutorBaseState
{
[JsonPropertyName("n_past")]
@@ -360,13 +389,13 @@ namespace LLama
public string? SessionFilePath { get; set; }

[JsonPropertyName("embd")]
public List<LLamaToken> Embeds { get; set; }
public LLamaToken[] Embeds { get; set; }

[JsonPropertyName("embd_inps")]
public List<LLamaToken> EmbedInps { get; set; }
public LLamaToken[] EmbedInps { get; set; }

[JsonPropertyName("session_tokens")]
public List<LLamaToken> SessionTokens { get; set; }
public LLamaToken[] SessionTokens { get; set; }

[JsonPropertyName("last_n_tokens")]
public LLamaToken[] LastTokens { get; set; }


+ 6
- 6
LLama/LLamaInstructExecutor.cs View File

@@ -49,17 +49,17 @@ namespace LLama
InstructExecutorState state = new()
{
ConsumedSessionCount = _n_session_consumed,
EmbedInps = _embed_inps,
EmbedInps = _embed_inps.ToArray(),
IsPromptRun = _is_prompt_run,
ConsumedTokensCount = _consumedTokensCount,
Embeds = _embeds,
Embeds = _embeds.ToArray(),
LastTokens = _last_n_tokens.ToArray(),
InputPrefixTokens = _inp_pfx,
InputSuffixTokens = _inp_sfx,
MatchingSessionTokensCount = _n_matching_session_tokens,
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
SessionTokens = _session_tokens.ToArray(),
LastTokensCapacity = _last_n_tokens.Capacity,
MirostatMu = MirostatMu
};
@@ -71,17 +71,17 @@ namespace LLama
if(data is InstructExecutorState state)
{
_n_session_consumed = state.ConsumedSessionCount;
_embed_inps = state.EmbedInps;
_embed_inps = state.EmbedInps.ToList();
_is_prompt_run = state.IsPromptRun;
_consumedTokensCount = state.ConsumedTokensCount;
_embeds = state.Embeds;
_embeds = state.Embeds.ToList();
_last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens);
_inp_pfx = state.InputPrefixTokens;
_inp_sfx = state.InputSuffixTokens;
_n_matching_session_tokens = state.MatchingSessionTokensCount;
_pastTokensCount = state.PastTokensCount;
_pathSession = state.SessionFilePath;
_session_tokens = state.SessionTokens;
_session_tokens = state.SessionTokens.ToList();
}
else
{


+ 6
- 6
LLama/LLamaInteractExecutor.cs View File

@@ -39,15 +39,15 @@ namespace LLama
InteractiveExecutorState state = new()
{
ConsumedSessionCount = _n_session_consumed,
EmbedInps = _embed_inps,
EmbedInps = _embed_inps.ToArray(),
IsPromptRun = _is_prompt_run,
ConsumedTokensCount = _consumedTokensCount,
Embeds = _embeds,
Embeds = _embeds.ToArray(),
LastTokens = _last_n_tokens.ToArray(),
MatchingSessionTokensCount = _n_matching_session_tokens,
PastTokensCount = _pastTokensCount,
SessionFilePath = _pathSession,
SessionTokens = _session_tokens,
SessionTokens = _session_tokens.ToArray(),
LastTokensCapacity = _last_n_tokens.Capacity,
MirostatMu = MirostatMu
};
@@ -59,15 +59,15 @@ namespace LLama
if (data is InteractiveExecutorState state)
{
_n_session_consumed = state.ConsumedSessionCount;
_embed_inps = state.EmbedInps;
_embed_inps = state.EmbedInps.ToList();
_is_prompt_run = state.IsPromptRun;
_consumedTokensCount = state.ConsumedTokensCount;
_embeds = state.Embeds;
_embeds = state.Embeds.ToList();
_last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens);
_n_matching_session_tokens = state.MatchingSessionTokensCount;
_pastTokensCount = state.PastTokensCount;
_pathSession = state.SessionFilePath;
_session_tokens = state.SessionTokens;
_session_tokens = state.SessionTokens.ToList();
}
else
throw new ArgumentException("Invalid state data type.");


+ 67
- 0
LLama/LLamaTransforms.cs View File

@@ -3,6 +3,7 @@ using LLama.Common;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;

namespace LLama
{
@@ -29,6 +30,12 @@ namespace LLama
private readonly string _unknownName;
private readonly bool _isInstructMode;

public string UserName => _userName;
public string AssistantName => _assistantName;
public string SystemName => _systemName;
public string UnknownName => _unknownName;
public bool IsInstructMode => _isInstructMode;

/// <summary>
///
/// </summary>
@@ -47,6 +54,12 @@ namespace LLama
_isInstructMode = isInstructMode;
}

/// <inheritdoc />
public IHistoryTransform Clone()
{
return new DefaultHistoryTransform(_userName, _assistantName, _systemName, _unknownName, _isInstructMode);
}

/// <inheritdoc />
public virtual string HistoryToText(ChatHistory history)
{
@@ -116,6 +129,12 @@ namespace LLama
{
return text.Trim();
}

/// <inheritdoc />
public ITextTransform Clone()
{
return new NaiveTextInputTransform();
}
}

/// <summary>
@@ -129,6 +148,12 @@ namespace LLama
{
return tokens;
}

/// <inheritdoc />
public ITextStreamTransform Clone()
{
return new EmptyTextOutputStreamTransform();
}
}

/// <summary>
@@ -140,6 +165,42 @@ namespace LLama
private readonly int _maxKeywordLength;
private readonly bool _removeAllMatchedTokens;

/// <summary>
/// Keywords that you want to remove from the response.
/// This property is used for JSON serialization.
/// </summary>
[JsonPropertyName("keywords")]
public HashSet<string> Keywords => _keywords;

/// <summary>
/// Maximum length of the keywords.
/// This property is used for JSON serialization.
/// </summary>
[JsonPropertyName("maxKeywordLength")]
public int MaxKeywordLength => _maxKeywordLength;

/// <summary>
/// If set to true, when getting a matched keyword, all the related tokens will be removed.
/// Otherwise only the part of keyword will be removed.
/// This property is used for JSON serialization.
/// </summary>
[JsonPropertyName("removeAllMatchedTokens")]
public bool RemoveAllMatchedTokens => _removeAllMatchedTokens;

/// <summary>
/// JSON constructor.
/// </summary>
[JsonConstructor]
public KeywordTextOutputStreamTransform(
HashSet<string> keywords,
int maxKeywordLength,
bool removeAllMatchedTokens)
{
_keywords = new(keywords);
_maxKeywordLength = maxKeywordLength;
_removeAllMatchedTokens = removeAllMatchedTokens;
}

/// <summary>
///
/// </summary>
@@ -157,6 +218,12 @@ namespace LLama
_removeAllMatchedTokens = removeAllMatchedTokens;
}

/// <inheritdoc />
public ITextStreamTransform Clone()
{
return new KeywordTextOutputStreamTransform(_keywords, _maxKeywordLength, _removeAllMatchedTokens);
}

/// <inheritdoc />
public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens)
{


Loading…
Cancel
Save