Chat session state management (tags/0.11.0)
| @@ -8,6 +8,7 @@ public class ExampleRunner | |||
| { "Chat Session: History", ChatSessionWithHistory.Run }, | |||
| { "Chat Session: Role names", ChatSessionWithRoleName.Run }, | |||
| { "Chat Session: Role names stripped", ChatSessionStripRoleName.Run }, | |||
| { "Chat Session: Pre-processing and reset", ChatSessionWithRestart.Run }, | |||
| { "Chat Session: Coding Assistant", CodingAssistant.Run }, | |||
| { "Chat Session: Automatic conversation", TalkToYourself.Run }, | |||
| { "Chat Session: Chinese characters", ChatChineseGB2312.Run }, | |||
| @@ -48,6 +48,10 @@ public class ChatSessionWithHistory | |||
| Console.ForegroundColor = ConsoleColor.Yellow; | |||
| Console.WriteLine("The chat session has started."); | |||
| Console.WriteLine("Type 'exit' to end the chat session."); | |||
| Console.WriteLine("Type 'save' to save the chat session to disk."); | |||
| Console.WriteLine("Type 'load' to load the chat session from disk."); | |||
| Console.WriteLine("Type 'regenerate' to regenerate the last response."); | |||
| // show the prompt | |||
| Console.ForegroundColor = ConsoleColor.Green; | |||
| @@ -55,12 +59,20 @@ public class ChatSessionWithHistory | |||
| while (userInput != "exit") | |||
| { | |||
| // Save the chat state to disk | |||
| if (userInput == "save") | |||
| { | |||
| session.SaveSession("Assets/chat-with-bob"); | |||
| Console.ForegroundColor = ConsoleColor.Yellow; | |||
| Console.WriteLine("Session saved."); | |||
| } | |||
| // Load the chat state from disk | |||
| else if (userInput == "load") | |||
| { | |||
| session.LoadSession("Assets/chat-with-bob"); | |||
| Console.ForegroundColor = ConsoleColor.Yellow; | |||
| Console.WriteLine("Session loaded."); | |||
| } | |||
| else if (userInput == "regenerate") | |||
| { | |||
| Console.ForegroundColor = ConsoleColor.Yellow; | |||
| @@ -0,0 +1,107 @@ | |||
| using LLama.Common; | |||
| namespace LLama.Examples.Examples; | |||
/// <summary>
/// Console example: pre-processes a stored chat history once, keeps an in-memory
/// snapshot of the resulting session state, and lets the user roll the conversation
/// back to that snapshot (or replace the snapshot) while chatting.
/// </summary>
public class ChatSessionWithRestart
{
    /// <summary>
    /// Runs the interactive loop until the user types 'exit'.
    /// </summary>
    public static async Task Run()
    {
        var modelPath = UserSettings.GetModelPath();
        var parameters = new ModelParams(modelPath)
        {
            ContextSize = 1024,
            Seed = 1337,
            GpuLayerCount = 5
        };

        using var model = LLamaWeights.LoadFromFile(parameters);
        using var context = model.CreateContext(parameters);
        var executor = new InteractiveExecutor(context);

        // Seed history shipped with the examples; fall back to an empty history if it cannot be parsed.
        var seedJson = File.ReadAllText("Assets/chat-with-bob.json");
        var seedHistory = ChatHistory.FromJson(seedJson) ?? new ChatHistory();

        // The prototype session pre-processes the seed history; its state becomes the first reset point.
        var prototypeSession = await ChatSession.InitializeSessionFromHistoryAsync(executor, seedHistory);
        var stopWords = new string[] { "User:", "Assistant:" };
        prototypeSession.WithOutputTransform(
            new LLamaTransforms.KeywordTextOutputStreamTransform(stopWords, redundancyLength: 8));
        var resetState = prototypeSession.GetSessionState();

        // The session the user actually talks to starts from the snapshot.
        var session = new ChatSession(executor);
        session.LoadSession(resetState);

        var inferenceParams = new InferenceParams()
        {
            Temperature = 0.9f,
            AntiPrompts = new List<string> { "User:" }
        };

        // Writes a status line in yellow.
        static void Notice(string text)
        {
            Console.ForegroundColor = ConsoleColor.Yellow;
            Console.WriteLine(text);
        }

        // Switches to the green prompt color and reads one line (empty string on end of input).
        static string Prompt()
        {
            Console.ForegroundColor = ConsoleColor.Green;
            return Console.ReadLine() ?? "";
        }

        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine("The chat session has started. Starting point saved.");
        Console.WriteLine("Type 'exit' to end the chat session.");
        Console.WriteLine("Type 'save' to save chat session state in memory.");
        Console.WriteLine("Type 'reset' to reset the chat session to its saved state.");
        Console.WriteLine("Type 'answer for assistant' to add and process provided user and assistant messages.");

        // show the prompt
        var userInput = Prompt();
        while (userInput != "exit")
        {
            switch (userInput)
            {
                case "reset":
                    // Roll the session back to the current reset point and show what was restored.
                    session.LoadSession(resetState);
                    Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.History)}");
                    Notice("Session reset.");
                    break;

                case "save":
                    // Make the current state the new reset point.
                    resetState = session.GetSessionState();
                    Notice("Session saved.");
                    break;

                case "answer for assistant":
                    // The user supplies both sides of one exchange; both are processed without generation.
                    Notice("Provide user input: ");
                    var userInputOverride = Prompt();
                    Notice("Provide assistant input: ");
                    var assistantInputOverride = Prompt();

                    await session.AddAndProcessUserMessage(userInputOverride);
                    await session.AddAndProcessAssistantMessage(assistantInputOverride);

                    Notice("User and assistant messages processed. Provide next user message:");
                    break;

                default:
                    // Ordinary chat turn: stream the model's reply to the console.
                    var userMessage = new ChatHistory.Message(AuthorRole.User, userInput);
                    await foreach (var text in session.ChatAsync(userMessage, inferenceParams))
                    {
                        Console.ForegroundColor = ConsoleColor.White;
                        Console.Write(text);
                    }
                    break;
            }

            userInput = Prompt();
            Console.ForegroundColor = ConsoleColor.White;
        }
    }
}
| @@ -1,10 +1,12 @@ | |||
| using LLama.Common; | |||
| using System.Text.Json.Serialization; | |||
| namespace LLama.Abstractions | |||
| { | |||
| /// <summary> | |||
| /// Transform history to plain text and vice versa. | |||
| /// </summary> | |||
| [JsonConverter(typeof(PolymorphicJSONConverter<IHistoryTransform>))] | |||
| public interface IHistoryTransform | |||
| { | |||
| /// <summary> | |||
| @@ -21,5 +23,11 @@ namespace LLama.Abstractions | |||
| /// <param name="text">The chat history as plain text.</param> | |||
| /// <returns>The updated history.</returns> | |||
| ChatHistory TextToHistory(AuthorRole role, string text); | |||
| /// <summary> | |||
| /// Copy the transform. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| IHistoryTransform Clone(); | |||
| } | |||
| } | |||
| @@ -1,10 +1,13 @@ | |||
| using System.Collections.Generic; | |||
| using LLama.Common; | |||
| using System.Collections.Generic; | |||
| using System.Text.Json.Serialization; | |||
| namespace LLama.Abstractions | |||
| { | |||
| /// <summary> | |||
| /// Takes a stream of tokens and transforms them. | |||
| /// </summary> | |||
| [JsonConverter(typeof(PolymorphicJSONConverter<ITextStreamTransform>))] | |||
| public interface ITextStreamTransform | |||
| { | |||
| /// <summary> | |||
| @@ -13,5 +16,11 @@ namespace LLama.Abstractions | |||
| /// <param name="tokens"></param> | |||
| /// <returns></returns> | |||
| IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens); | |||
| /// <summary> | |||
| /// Copy the transform. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| ITextStreamTransform Clone(); | |||
| } | |||
| } | |||
| @@ -1,4 +1,7 @@ | |||
| namespace LLama.Abstractions | |||
| using System.Text.Json.Serialization; | |||
| using LLama.Common; | |||
| namespace LLama.Abstractions | |||
| { | |||
| /// <summary> | |||
| /// An interface for text transformations. | |||
| @@ -9,6 +12,7 @@ | |||
| /// - Trimming | |||
| /// - etc. | |||
| /// </summary> | |||
| [JsonConverter(typeof(PolymorphicJSONConverter<ITextTransform>))] | |||
| public interface ITextTransform | |||
| { | |||
| /// <summary> | |||
| @@ -17,5 +21,11 @@ | |||
| /// <param name="text"></param> | |||
| /// <returns></returns> | |||
| string Transform(string text); | |||
| /// <summary> | |||
| /// Copy the transform. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| ITextTransform Clone(); | |||
| } | |||
| } | |||
| @@ -3,11 +3,14 @@ using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Runtime.CompilerServices; | |||
| using System.Text.Json; | |||
| using System.Threading; | |||
| using System.Threading.Tasks; | |||
| using LLama.Abstractions; | |||
| using LLama.Common; | |||
| using static LLama.InteractiveExecutor; | |||
| using static LLama.LLamaContext; | |||
| using static LLama.StatefulExecutorBase; | |||
| namespace LLama; | |||
| @@ -16,9 +19,30 @@ namespace LLama; | |||
| /// </summary> | |||
| public class ChatSession | |||
| { | |||
| private const string _modelStateFilename = "ModelState.st"; | |||
| private const string _executorStateFilename = "ExecutorState.json"; | |||
| private const string _hsitoryFilename = "ChatHistory.json"; | |||
| /// <summary> | |||
| /// The filename for the serialized model state (KV cache, etc). | |||
| /// </summary> | |||
| public const string MODEL_STATE_FILENAME = "ModelState.st"; | |||
| /// <summary> | |||
| /// The filename for the serialized executor state. | |||
| /// </summary> | |||
| public const string EXECUTOR_STATE_FILENAME = "ExecutorState.json"; | |||
| /// <summary> | |||
| /// The filename for the serialized chat history. | |||
| /// </summary> | |||
| public const string HISTORY_STATE_FILENAME = "ChatHistory.json"; | |||
| /// <summary> | |||
| /// The filename for the serialized input transform pipeline. | |||
| /// </summary> | |||
| public const string INPUT_TRANSFORM_FILENAME = "InputTransform.json"; | |||
| /// <summary> | |||
| /// The filename for the serialized output transform. | |||
| /// </summary> | |||
| public const string OUTPUT_TRANSFORM_FILENAME = "OutputTransform.json"; | |||
| /// <summary> | |||
| /// The filename for the serialized history transform. | |||
| /// </summary> | |||
| public const string HISTORY_TRANSFORM_FILENAME = "HistoryTransform.json"; | |||
| /// <summary> | |||
| /// The executor for this session. | |||
| @@ -45,6 +69,24 @@ public class ChatSession | |||
| /// </summary> | |||
| public ITextStreamTransform OutputTransform = new LLamaTransforms.EmptyTextOutputStreamTransform(); | |||
/// <summary>
/// Create a chat session for the given history and run the rendered history text
/// through the executor's prompt prefill before returning the session.
/// </summary>
/// <param name="executor">The executor for this session; must be a <see cref="StatefulExecutorBase"/>.</param>
/// <param name="history">History for this session</param>
/// <returns>The new session, after the history has been prefilled.</returns>
/// <exception cref="ArgumentException">Thrown when <paramref name="executor"/> is not a <see cref="StatefulExecutorBase"/>.</exception>
public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
    ILLamaExecutor executor, ChatHistory history)
{
    var statefulExecutor = executor as StatefulExecutorBase
        ?? throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor));

    var session = new ChatSession(executor, history);

    // Render the history with the session's own history transform, then prefill with that text.
    var historyText = session.HistoryTransform.HistoryToText(history);
    await statefulExecutor.PrefillPromptAsync(historyText);

    return session;
}
| /// <summary> | |||
| /// Create a new chat session. | |||
| /// </summary> | |||
| @@ -112,56 +154,76 @@ public class ChatSession | |||
| /// <exception cref="ArgumentException"></exception> | |||
| public void SaveSession(string path) | |||
| { | |||
| if (string.IsNullOrWhiteSpace(path)) | |||
| GetSessionState().Save(path); | |||
| } | |||
/// <summary>
/// Capture the current session as an in-memory <see cref="SessionState"/>.
/// </summary>
/// <remarks>
/// The context state is only captured when the executor reports a past-token count
/// greater than zero; otherwise the snapshot carries a null context state.
/// </remarks>
/// <returns>SessionState object representing session state in-memory</returns>
public SessionState GetSessionState()
{
    var executorState = ((StatefulExecutorBase)Executor).GetStateData();

    State? contextState = null;
    if (executorState.PastTokensCount > 0)
    {
        contextState = Executor.Context.GetState();
    }

    return new SessionState(
        contextState,
        executorState,
        History,
        InputTransformPipeline,
        OutputTransform,
        HistoryTransform);
}
| /// <summary> | |||
| /// Load a session from a session state. | |||
| /// </summary> | |||
| /// <param name="state"></param> | |||
| /// <param name="loadTransforms">If true loads transforms saved in the session state.</param> | |||
| /// <returns></returns> | |||
| /// <exception cref="ArgumentException"></exception> | |||
| public void LoadSession(SessionState state, bool loadTransforms = true) | |||
| { | |||
| if (Executor is StatefulExecutorBase statefulExecutor) | |||
| { | |||
| throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); | |||
| if (state.ExecutorState is not null) | |||
| { | |||
| statefulExecutor.LoadState(state.ExecutorState); | |||
| } | |||
| } | |||
| if (Directory.Exists(path)) | |||
| if (state.ContextState is null) | |||
| { | |||
| Directory.Delete(path, recursive: true); | |||
| Executor.Context.NativeHandle.KvCacheClear(); | |||
| } | |||
| else | |||
| { | |||
| Executor.Context.LoadState(state.ContextState); | |||
| } | |||
| History = new ChatHistory(state.History); | |||
| if (loadTransforms) | |||
| { | |||
| InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList(); | |||
| OutputTransform = state.OutputTransform.Clone(); | |||
| HistoryTransform = state.HistoryTransform.Clone(); | |||
| } | |||
| Directory.CreateDirectory(path); | |||
| string modelStateFilePath = Path.Combine(path, _modelStateFilename); | |||
| Executor.Context.SaveState(modelStateFilePath); | |||
| string executorStateFilepath = Path.Combine(path, _executorStateFilename); | |||
| ((StatefulExecutorBase)Executor).SaveState(executorStateFilepath); | |||
| string historyFilepath = Path.Combine(path, _hsitoryFilename); | |||
| File.WriteAllText(historyFilepath, History.ToJson()); | |||
| } | |||
| /// <summary> | |||
| /// Load a session from a directory. | |||
| /// </summary> | |||
| /// <param name="path"></param> | |||
| /// <param name="loadTransforms">If true loads transforms saved in the session state.</param> | |||
| /// <returns></returns> | |||
| /// <exception cref="ArgumentException"></exception> | |||
| public void LoadSession(string path) | |||
| public void LoadSession(string path, bool loadTransforms = true) | |||
| { | |||
| if (string.IsNullOrWhiteSpace(path)) | |||
| var state = SessionState.Load(path); | |||
| // Handle non-polymorphic serialization of executor state | |||
| if (state.ExecutorState is null) | |||
| { | |||
| throw new ArgumentException("Path cannot be null or whitespace", nameof(path)); | |||
| } | |||
| if (!Directory.Exists(path)) | |||
| { | |||
| throw new ArgumentException("Directory does not exist", nameof(path)); | |||
| var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME); | |||
| ((StatefulExecutorBase) Executor).LoadState(filename: executorPath); | |||
| } | |||
| string modelStateFilePath = Path.Combine(path, _modelStateFilename); | |||
| Executor.Context.LoadState(modelStateFilePath); | |||
| string executorStateFilepath = Path.Combine(path, _executorStateFilename); | |||
| ((StatefulExecutorBase)Executor).LoadState(executorStateFilepath); | |||
| string historyFilepath = Path.Combine(path, _hsitoryFilename); | |||
| string historyJson = File.ReadAllText(historyFilepath); | |||
| History = ChatHistory.FromJson(historyJson) | |||
| ?? throw new ArgumentException("History file is invalid", nameof(path)); | |||
| LoadSession(state, loadTransforms); | |||
| } | |||
| /// <summary> | |||
| @@ -238,6 +300,49 @@ public class ChatSession | |||
| return this; | |||
| } | |||
/// <summary>
/// Add the message to the chat history and prefill the executor with its content.
/// </summary>
/// <remarks>
/// For every role except <see cref="AuthorRole.Assistant"/> the content is first passed
/// through each transform in <see cref="InputTransformPipeline"/>, in order.
/// </remarks>
/// <param name="message">The message to record and process.</param>
/// <returns>This session, to allow chaining.</returns>
/// <exception cref="InvalidOperationException">Thrown when the executor is not a <see cref="StatefulExecutorBase"/>.</exception>
public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
{
    var statefulExecutor = Executor as StatefulExecutorBase
        ?? throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");

    AddMessage(message);

    // Assistant content is prefilled verbatim; everything else goes through the input pipeline.
    var content = message.AuthorRole == AuthorRole.Assistant
        ? message.Content
        : InputTransformPipeline.Aggregate(
            message.Content,
            (text, inputTransform) => inputTransform.Transform(text));

    await statefulExecutor.PrefillPromptAsync(content);
    return this;
}
/// <summary>
/// Wrap <paramref name="content"/> in a system-role message and hand it to <see cref="AddAndProcessMessage"/>.
/// </summary>
/// <param name="content">Text of the system message.</param>
/// <returns>The task returned by <see cref="AddAndProcessMessage"/>.</returns>
public Task<ChatSession> AddAndProcessSystemMessage(string content)
{
    var systemMessage = new ChatHistory.Message(AuthorRole.System, content);
    return AddAndProcessMessage(systemMessage);
}
/// <summary>
/// Wrap <paramref name="content"/> in a user-role message and hand it to <see cref="AddAndProcessMessage"/>.
/// </summary>
/// <param name="content">Text of the user message.</param>
/// <returns>The task returned by <see cref="AddAndProcessMessage"/>.</returns>
public Task<ChatSession> AddAndProcessUserMessage(string content)
{
    var userMessage = new ChatHistory.Message(AuthorRole.User, content);
    return AddAndProcessMessage(userMessage);
}
/// <summary>
/// Wrap <paramref name="content"/> in an assistant-role message and hand it to <see cref="AddAndProcessMessage"/>.
/// </summary>
/// <param name="content">Text of the assistant message.</param>
/// <returns>The task returned by <see cref="AddAndProcessMessage"/>.</returns>
public Task<ChatSession> AddAndProcessAssistantMessage(string content)
{
    var assistantMessage = new ChatHistory.Message(AuthorRole.Assistant, content);
    return AddAndProcessMessage(assistantMessage);
}
| /// <summary> | |||
| /// Replace a user message with a new message and remove all messages after the new message. | |||
| /// This is useful when the user wants to edit a message. And regenerate the response. | |||
| @@ -494,3 +599,185 @@ public class ChatSession | |||
| } | |||
| } | |||
| } | |||
/// <summary>
/// The state of a chat session in-memory.
/// </summary>
/// <remarks>
/// The constructor copies the history messages into a new array and clones every
/// transform, so the snapshot does not share those objects with the session it was
/// taken from. <see cref="Save"/> and <see cref="Load"/> persist the snapshot as a set
/// of files inside one directory, using the file names declared on <see cref="ChatSession"/>.
/// </remarks>
public record SessionState
{
    /// <summary>
    /// Saved executor state for the session. May be null when no executor state
    /// could be read back (callers such as <c>ChatSession.LoadSession</c> check for this).
    /// </summary>
    public ExecutorBaseState? ExecutorState { get; set; }

    /// <summary>
    /// Saved context state (KV cache) for the session, or null when none was captured.
    /// </summary>
    public State? ContextState { get; set; }

    /// <summary>
    /// The input transform pipeline used in this session.
    /// </summary>
    public ITextTransform[] InputTransformPipeline { get; set; } = Array.Empty<ITextTransform>();

    /// <summary>
    /// The output transform used in this session.
    /// </summary>
    public ITextStreamTransform OutputTransform { get; set; } = new LLamaTransforms.EmptyTextOutputStreamTransform();

    /// <summary>
    /// The history transform used in this session.
    /// </summary>
    public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();

    /// <summary>
    /// The chat history messages for this session.
    /// </summary>
    public ChatHistory.Message[] History { get; set; } = Array.Empty<ChatHistory.Message>();

    /// <summary>
    /// Create a new session state.
    /// </summary>
    /// <param name="contextState">Context state (KV cache), or null when there is none.</param>
    /// <param name="executorState">Executor state, or null when it is not available.</param>
    /// <param name="history">Chat history; its messages are copied into a new array.</param>
    /// <param name="inputTransformPipeline">Input transforms; each one is cloned.</param>
    /// <param name="outputTransform">Output transform; cloned.</param>
    /// <param name="historyTransform">History transform; cloned.</param>
    public SessionState(
        State? contextState, ExecutorBaseState? executorState,
        ChatHistory history, List<ITextTransform> inputTransformPipeline,
        ITextStreamTransform outputTransform, IHistoryTransform historyTransform)
    {
        ContextState = contextState;
        ExecutorState = executorState;
        History = history.Messages.ToArray();
        InputTransformPipeline = inputTransformPipeline.Select(t => t.Clone()).ToArray();
        OutputTransform = outputTransform.Clone();
        HistoryTransform = historyTransform.Clone();
    }

    /// <summary>
    /// Save the session state to folder.
    /// </summary>
    /// <remarks>
    /// An existing directory at <paramref name="path"/> is deleted recursively before the
    /// new files are written. The model state file is only written when a context state exists.
    /// </remarks>
    /// <param name="path">Directory to write the state files into.</param>
    /// <exception cref="ArgumentException">Thrown when <paramref name="path"/> is null or whitespace.</exception>
    public void Save(string path)
    {
        if (string.IsNullOrWhiteSpace(path))
        {
            throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
        }

        // Start from an empty directory so files from an earlier save cannot linger.
        if (Directory.Exists(path))
        {
            Directory.Delete(path, recursive: true);
        }

        Directory.CreateDirectory(path);

        string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
        var bytes = ContextState?.ToByteArray();
        if (bytes is not null)
        {
            File.WriteAllBytes(modelStateFilePath, bytes);
        }

        string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
        File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState));

        string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
        File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson());

        string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME);
        File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline));

        string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);
        File.WriteAllText(outputTransformFilepath, JsonSerializer.Serialize(OutputTransform));

        string historyTransformFilepath = Path.Combine(path, ChatSession.HISTORY_TRANSFORM_FILENAME);
        File.WriteAllText(historyTransformFilepath, JsonSerializer.Serialize(HistoryTransform));
    }

    /// <summary>
    /// Load the session state from folder.
    /// </summary>
    /// <remarks>
    /// The model state file and the three transform files are optional: a missing model
    /// state yields a null <see cref="ContextState"/>, and a missing transform file yields
    /// the corresponding default transform. The executor state and history files are read
    /// unconditionally.
    /// </remarks>
    /// <param name="path">Directory previously written by <see cref="Save"/>.</param>
    /// <returns>The loaded session state.</returns>
    /// <exception cref="ArgumentException">Throws when session state is incorrect</exception>
    public static SessionState Load(string path)
    {
        if (string.IsNullOrWhiteSpace(path))
        {
            throw new ArgumentException("Path cannot be null or whitespace", nameof(path));
        }

        if (!Directory.Exists(path))
        {
            throw new ArgumentException("Directory does not exist", nameof(path));
        }

        string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
        var contextState = File.Exists(modelStateFilePath)
            ? State.FromByteArray(File.ReadAllBytes(modelStateFilePath))
            : null;

        // May come back null; ChatSession.LoadSession(string) handles that case itself.
        string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
        var executorState = JsonSerializer.Deserialize<ExecutorBaseState>(File.ReadAllText(executorStateFilepath));

        string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME);
        string historyJson = File.ReadAllText(historyFilepath);
        var history = ChatHistory.FromJson(historyJson)
            ?? throw new ArgumentException("History file is invalid", nameof(path));

        var inputTransforms = LoadJsonOrDefault<ITextTransform[]>(
            path,
            ChatSession.INPUT_TRANSFORM_FILENAME,
            "Input transform file is invalid",
            () => Array.Empty<ITextTransform>());

        var outputTransform = LoadJsonOrDefault<ITextStreamTransform>(
            path,
            ChatSession.OUTPUT_TRANSFORM_FILENAME,
            "Output transform file is invalid",
            () => new LLamaTransforms.EmptyTextOutputStreamTransform());

        var historyTransform = LoadJsonOrDefault<IHistoryTransform>(
            path,
            ChatSession.HISTORY_TRANSFORM_FILENAME,
            "History transform file is invalid",
            () => new LLamaTransforms.DefaultHistoryTransform());

        return new SessionState(
            contextState,
            executorState,
            history,
            inputTransforms.ToList(),
            outputTransform,
            historyTransform);
    }

    /// <summary>
    /// Deserialize one optional JSON file from the state directory: a missing file yields
    /// the default, while a file that is malformed or deserializes to null is reported as
    /// an <see cref="ArgumentException"/> for the directory path.
    /// </summary>
    private static T LoadJsonOrDefault<T>(string path, string fileName, string invalidMessage, Func<T> createDefault)
        where T : class
    {
        string filePath = Path.Combine(path, fileName);
        try
        {
            if (!File.Exists(filePath))
            {
                return createDefault();
            }

            return JsonSerializer.Deserialize<T>(File.ReadAllText(filePath))
                ?? throw new ArgumentException(invalidMessage, nameof(path));
        }
        catch (JsonException ex)
        {
            // Keep the parser's diagnostics available to the caller as the inner exception.
            throw new ArgumentException(invalidMessage, nameof(path), ex);
        }
    }
}
| @@ -1,4 +1,5 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text.Json; | |||
| using System.Text.Json.Serialization; | |||
| @@ -80,6 +81,15 @@ namespace LLama.Common | |||
| [JsonConstructor] | |||
| public ChatHistory() { } | |||
/// <summary>
/// Create a chat history whose message list is a new list holding the given messages.
/// </summary>
/// <param name="messageHistory">Messages to copy into the new history, in order.</param>
public ChatHistory(Message[] messageHistory)
{
    var copiedMessages = Enumerable.ToList(messageHistory);
    Messages = copiedMessages;
}
| /// <summary> | |||
| /// Add a message to the chat history | |||
| /// </summary> | |||
| @@ -0,0 +1,57 @@ | |||
| using LLama.Abstractions; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Reflection; | |||
| using System.Text; | |||
| using System.Text.Json; | |||
| using System.Text.Json.Serialization; | |||
| namespace LLama.Common | |||
| { | |||
/// <summary>
/// JSON converter that round-trips a value of abstract/interface type <typeparamref name="T"/>
/// through an envelope of the form <c>{ "Name": "&lt;concrete type name&gt;", "Data": { ... } }</c>.
/// </summary>
/// <remarks>
/// On read, the concrete type is looked up by its simple name among the non-abstract,
/// non-interface types of the executing assembly that are assignable to <typeparamref name="T"/>.
/// A JSON object whose first property is not <c>Name</c> is treated as "not an envelope":
/// it is consumed in full and the default value is returned, so callers can fall back to
/// another way of loading the data.
/// </remarks>
/// <typeparam name="T">The declared (base or interface) type being serialized.</typeparam>
internal class PolymorphicJSONConverter<T> : JsonConverter<T>
{
    /// <summary>
    /// Read an envelope and materialize the concrete type it names.
    /// </summary>
    /// <returns>
    /// The deserialized value, or the default value when the object does not start with a <c>Name</c> property.
    /// </returns>
    /// <exception cref="JsonException">
    /// Thrown when the token stream does not match the envelope layout, the named type
    /// cannot be found, or the payload deserializes to null.
    /// </exception>
    public override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        if (reader.TokenType != JsonTokenType.StartObject)
            throw new JsonException();

        // The envelope's closing EndObject reports the same depth as its StartObject.
        int envelopeDepth = reader.CurrentDepth;

        reader.Read();
        if (reader.TokenType != JsonTokenType.PropertyName)
            throw new JsonException();

        string? propertyName = reader.GetString();
        if (propertyName != "Name")
        {
            // A converter has to stop on the final token of the value it was handed,
            // so consume the rest of this object before reporting "no value".
            SkipToEndOfObject(ref reader, envelopeDepth);
            return default;
        }

        reader.Read();
        if (reader.TokenType != JsonTokenType.String)
            throw new JsonException();

        string name = reader.GetString() ?? throw new JsonException();

        var type = Assembly.GetExecutingAssembly().GetTypes().FirstOrDefault(
            t => typeof(T).IsAssignableFrom(t) && !t.IsAbstract && !t.IsInterface && t.Name == name);
        if (type == null)
            throw new JsonException();

        reader.Read();
        if (reader.TokenType != JsonTokenType.PropertyName)
            throw new JsonException();

        propertyName = reader.GetString();
        if (propertyName != "Data")
            throw new JsonException();

        // With the reader on the "Data" property name, Deserialize advances to the value
        // itself and leaves the reader on that value's last token.
        var data = JsonSerializer.Deserialize(ref reader, type, options);
        if (data == null)
            throw new JsonException();

        // Finish exactly on the envelope's EndObject. Reading one token further would step
        // into the next array element (or the array end) when envelopes are array items.
        SkipToEndOfObject(ref reader, envelopeDepth);

        return (T)data;
    }

    /// <summary>
    /// Write <paramref name="value"/> as an envelope carrying its runtime type name and its
    /// serialized form (serialized as the runtime type, not as <typeparamref name="T"/>).
    /// </summary>
    public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
    {
        writer.WriteStartObject();
        writer.WriteString("Name", value.GetType().Name);
        writer.WritePropertyName("Data");
        JsonSerializer.Serialize(writer, value, value.GetType(), options);
        writer.WriteEndObject();
    }

    /// <summary>
    /// Advance the reader until it sits on the EndObject token at <paramref name="depth"/>.
    /// </summary>
    /// <exception cref="JsonException">Thrown when the input ends before that token is reached.</exception>
    private static void SkipToEndOfObject(ref Utf8JsonReader reader, int depth)
    {
        while (reader.Read())
        {
            if (reader.TokenType == JsonTokenType.EndObject && reader.CurrentDepth == depth)
                return;
        }

        throw new JsonException();
    }
}
| } | |||
| @@ -204,7 +204,7 @@ namespace LLama | |||
| memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize); | |||
| // Wrap memory in a "state" | |||
| var state = new State(memory); | |||
| var state = new State(memory, actualSize); | |||
| // Set memory to zero, to prevent it being freed in finally block | |||
| memory = IntPtr.Zero; | |||
| @@ -422,9 +422,12 @@ namespace LLama | |||
| public class State | |||
| : SafeLLamaHandleBase | |||
| { | |||
| internal State(IntPtr memory) | |||
| private ulong _size; | |||
| internal State(IntPtr memory, ulong size) | |||
| : base(memory, true) | |||
| { | |||
| _size = size; | |||
| } | |||
| /// <inheritdoc /> | |||
| @@ -433,6 +436,29 @@ namespace LLama | |||
| Marshal.FreeHGlobal(handle); | |||
| return true; | |||
| } | |||
/// <summary>
/// Copy the native memory held by this state into a new managed byte array.
/// </summary>
/// <returns>A new array of <c>_size</c> bytes copied from the underlying handle.</returns>
public byte[] ToByteArray()
{
    var byteCount = (int)_size;
    var managedCopy = new byte[_size];
    Marshal.Copy(handle, managedCopy, 0, byteCount);
    return managedCopy;
}
/// <summary>
/// Create a state from a byte array by copying the bytes into newly allocated unmanaged memory.
/// </summary>
/// <param name="bytes">The serialized state bytes.</param>
/// <returns>A new <see cref="State"/> wrapping the unmanaged copy.</returns>
public static State FromByteArray(byte[] bytes)
{
    var byteCount = bytes.Length;

    // The State releases this allocation with Marshal.FreeHGlobal when its handle is released.
    var unmanagedCopy = Marshal.AllocHGlobal(byteCount);
    Marshal.Copy(bytes, 0, unmanagedCopy, byteCount);

    return new State(unmanagedCopy, (ulong)byteCount);
}
| } | |||
| } | |||
| } | |||
| @@ -315,6 +315,34 @@ namespace LLama | |||
| } | |||
| } | |||
/// <summary>
/// Asynchronously feeds a prompt through the model so that its KV cache is computed,
/// without generating any new tokens. Doing this ahead of time can lower the latency of
/// the first response when the user's first input does not arrive immediately.
/// </summary>
/// <param name="prompt">Prompt to process</param>
/// <returns>A task that completes once both inference passes have run.</returns>
public virtual async Task PrefillPromptAsync(string prompt)
{
    var prefillParams = new InferenceParams { MaxTokens = 0 };
    var prefillArgs = new InferStateArgs
    {
        Antiprompts = new List<string>(),
        RemainedTokens = 0,
        ReturnValue = false,
        WaitForInput = true,
        NeedToSaveSession = false
    };

    await PreprocessInputs(prompt, prefillArgs);

    // Two passes are required: the first adds the prompt to _embeds,
    // the second puts it through decode.
    for (var pass = 0; pass < 2; pass++)
    {
        await InferInternal(prefillParams, prefillArgs);
    }
}
| /// <summary> | |||
| /// State arguments that are used in single inference | |||
| /// </summary> | |||
| @@ -342,6 +370,7 @@ namespace LLama | |||
| public bool NeedToSaveSession { get; set; } | |||
| } | |||
| [JsonConverter(typeof(PolymorphicJSONConverter<ExecutorBaseState>))] | |||
| public class ExecutorBaseState | |||
| { | |||
| [JsonPropertyName("n_past")] | |||
| @@ -360,13 +389,13 @@ namespace LLama | |||
| public string? SessionFilePath { get; set; } | |||
| [JsonPropertyName("embd")] | |||
| public List<LLamaToken> Embeds { get; set; } | |||
| public LLamaToken[] Embeds { get; set; } | |||
| [JsonPropertyName("embd_inps")] | |||
| public List<LLamaToken> EmbedInps { get; set; } | |||
| public LLamaToken[] EmbedInps { get; set; } | |||
| [JsonPropertyName("session_tokens")] | |||
| public List<LLamaToken> SessionTokens { get; set; } | |||
| public LLamaToken[] SessionTokens { get; set; } | |||
| [JsonPropertyName("last_n_tokens")] | |||
| public LLamaToken[] LastTokens { get; set; } | |||
| @@ -49,17 +49,17 @@ namespace LLama | |||
| InstructExecutorState state = new() | |||
| { | |||
| ConsumedSessionCount = _n_session_consumed, | |||
| EmbedInps = _embed_inps, | |||
| EmbedInps = _embed_inps.ToArray(), | |||
| IsPromptRun = _is_prompt_run, | |||
| ConsumedTokensCount = _consumedTokensCount, | |||
| Embeds = _embeds, | |||
| Embeds = _embeds.ToArray(), | |||
| LastTokens = _last_n_tokens.ToArray(), | |||
| InputPrefixTokens = _inp_pfx, | |||
| InputSuffixTokens = _inp_sfx, | |||
| MatchingSessionTokensCount = _n_matching_session_tokens, | |||
| PastTokensCount = _pastTokensCount, | |||
| SessionFilePath = _pathSession, | |||
| SessionTokens = _session_tokens, | |||
| SessionTokens = _session_tokens.ToArray(), | |||
| LastTokensCapacity = _last_n_tokens.Capacity, | |||
| MirostatMu = MirostatMu | |||
| }; | |||
| @@ -71,17 +71,17 @@ namespace LLama | |||
| if(data is InstructExecutorState state) | |||
| { | |||
| _n_session_consumed = state.ConsumedSessionCount; | |||
| _embed_inps = state.EmbedInps; | |||
| _embed_inps = state.EmbedInps.ToList(); | |||
| _is_prompt_run = state.IsPromptRun; | |||
| _consumedTokensCount = state.ConsumedTokensCount; | |||
| _embeds = state.Embeds; | |||
| _embeds = state.Embeds.ToList(); | |||
| _last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens); | |||
| _inp_pfx = state.InputPrefixTokens; | |||
| _inp_sfx = state.InputSuffixTokens; | |||
| _n_matching_session_tokens = state.MatchingSessionTokensCount; | |||
| _pastTokensCount = state.PastTokensCount; | |||
| _pathSession = state.SessionFilePath; | |||
| _session_tokens = state.SessionTokens; | |||
| _session_tokens = state.SessionTokens.ToList(); | |||
| } | |||
| else | |||
| { | |||
| @@ -39,15 +39,15 @@ namespace LLama | |||
| InteractiveExecutorState state = new() | |||
| { | |||
| ConsumedSessionCount = _n_session_consumed, | |||
| EmbedInps = _embed_inps, | |||
| EmbedInps = _embed_inps.ToArray(), | |||
| IsPromptRun = _is_prompt_run, | |||
| ConsumedTokensCount = _consumedTokensCount, | |||
| Embeds = _embeds, | |||
| Embeds = _embeds.ToArray(), | |||
| LastTokens = _last_n_tokens.ToArray(), | |||
| MatchingSessionTokensCount = _n_matching_session_tokens, | |||
| PastTokensCount = _pastTokensCount, | |||
| SessionFilePath = _pathSession, | |||
| SessionTokens = _session_tokens, | |||
| SessionTokens = _session_tokens.ToArray(), | |||
| LastTokensCapacity = _last_n_tokens.Capacity, | |||
| MirostatMu = MirostatMu | |||
| }; | |||
| @@ -59,15 +59,15 @@ namespace LLama | |||
| if (data is InteractiveExecutorState state) | |||
| { | |||
| _n_session_consumed = state.ConsumedSessionCount; | |||
| _embed_inps = state.EmbedInps; | |||
| _embed_inps = state.EmbedInps.ToList(); | |||
| _is_prompt_run = state.IsPromptRun; | |||
| _consumedTokensCount = state.ConsumedTokensCount; | |||
| _embeds = state.Embeds; | |||
| _embeds = state.Embeds.ToList(); | |||
| _last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens); | |||
| _n_matching_session_tokens = state.MatchingSessionTokensCount; | |||
| _pastTokensCount = state.PastTokensCount; | |||
| _pathSession = state.SessionFilePath; | |||
| _session_tokens = state.SessionTokens; | |||
| _session_tokens = state.SessionTokens.ToList(); | |||
| } | |||
| else | |||
| throw new ArgumentException("Invalid state data type."); | |||
| @@ -3,6 +3,7 @@ using LLama.Common; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Text.Json.Serialization; | |||
| namespace LLama | |||
| { | |||
| @@ -29,6 +30,12 @@ namespace LLama | |||
| private readonly string _unknownName; | |||
| private readonly bool _isInstructMode; | |||
/// <summary>
/// Gets the user name held by this transform.
/// </summary>
public string UserName { get => _userName; }

/// <summary>
/// Gets the assistant name held by this transform.
/// </summary>
public string AssistantName { get => _assistantName; }

/// <summary>
/// Gets the system name held by this transform.
/// </summary>
public string SystemName { get => _systemName; }

/// <summary>
/// Gets the name held by this transform for unknown authors.
/// </summary>
public string UnknownName { get => _unknownName; }

/// <summary>
/// Gets the instruct-mode flag this transform was constructed with.
/// </summary>
public bool IsInstructMode { get => _isInstructMode; }
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| @@ -47,6 +54,12 @@ namespace LLama | |||
| _isInstructMode = isInstructMode; | |||
| } | |||
/// <inheritdoc />
/// <remarks>
/// The copy is constructed from this instance's names and instruct-mode flag.
/// </remarks>
public IHistoryTransform Clone()
    => new DefaultHistoryTransform(_userName, _assistantName, _systemName, _unknownName, _isInstructMode);
| /// <inheritdoc /> | |||
| public virtual string HistoryToText(ChatHistory history) | |||
| { | |||
| @@ -116,6 +129,12 @@ namespace LLama | |||
| { | |||
| return text.Trim(); | |||
| } | |||
/// <inheritdoc />
/// <remarks>
/// Returns a newly constructed <see cref="NaiveTextInputTransform"/>; no arguments are carried over.
/// </remarks>
public ITextTransform Clone()
    => new NaiveTextInputTransform();
| } | |||
| /// <summary> | |||
| @@ -129,6 +148,12 @@ namespace LLama | |||
| { | |||
| return tokens; | |||
| } | |||
/// <inheritdoc />
/// <remarks>
/// Returns a newly constructed <see cref="EmptyTextOutputStreamTransform"/>; no arguments are carried over.
/// </remarks>
public ITextStreamTransform Clone()
    => new EmptyTextOutputStreamTransform();
| } | |||
| /// <summary> | |||
| @@ -140,6 +165,42 @@ namespace LLama | |||
| private readonly int _maxKeywordLength; | |||
| private readonly bool _removeAllMatchedTokens; | |||
/// <summary>
/// The keywords to remove from the response.
/// Exposed as a property so that it can be used for JSON serialization.
/// </summary>
[JsonPropertyName("keywords")]
public HashSet<string> Keywords { get => _keywords; }

/// <summary>
/// The maximum length of the keywords.
/// Exposed as a property so that it can be used for JSON serialization.
/// </summary>
[JsonPropertyName("maxKeywordLength")]
public int MaxKeywordLength { get => _maxKeywordLength; }

/// <summary>
/// When true, all tokens related to a matched keyword are removed;
/// otherwise only the keyword part itself is removed.
/// Exposed as a property so that it can be used for JSON serialization.
/// </summary>
[JsonPropertyName("removeAllMatchedTokens")]
public bool RemoveAllMatchedTokens { get => _removeAllMatchedTokens; }
/// <summary>
/// Constructor marked for use by the JSON serializer.
/// </summary>
/// <param name="keywords">Keywords to remove from the response; copied into a new set.</param>
/// <param name="maxKeywordLength">Maximum length of the keywords, stored as given.</param>
/// <param name="removeAllMatchedTokens">
/// If true, all tokens related to a matched keyword are removed; otherwise only the keyword part.
/// </param>
[JsonConstructor]
public KeywordTextOutputStreamTransform(
    HashSet<string> keywords,
    int maxKeywordLength,
    bool removeAllMatchedTokens)
{
    _removeAllMatchedTokens = removeAllMatchedTokens;
    _maxKeywordLength = maxKeywordLength;

    // Take a copy so this instance does not alias the set that was passed in.
    _keywords = new(keywords);
}
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| @@ -157,6 +218,12 @@ namespace LLama | |||
| _removeAllMatchedTokens = removeAllMatchedTokens; | |||
| } | |||
/// <inheritdoc />
/// <remarks>
/// The copy is constructed from this instance's keywords, maximum keyword length and removal flag.
/// </remarks>
public ITextStreamTransform Clone()
    => new KeywordTextOutputStreamTransform(_keywords, _maxKeywordLength, _removeAllMatchedTokens);
| /// <inheritdoc /> | |||
| public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens) | |||
| { | |||