From 9440f153da658dea9bd27937b8befe70fc5bea03 Mon Sep 17 00:00:00 2001 From: eublefar Date: Thu, 21 Mar 2024 12:14:15 +0100 Subject: [PATCH] Make process message method more flexible --- .../Examples/ChatSessionWithHistory.cs | 6 ++ .../Examples/ChatSessionWithRestart.cs | 36 ++++++---- LLama/ChatSession.cs | 70 ++++++++++++------- 3 files changed, 73 insertions(+), 39 deletions(-) diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index 6a84d2fd..31b6a771 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -48,6 +48,10 @@ public class ChatSessionWithHistory Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The chat session has started."); + Console.WriteLine("Type 'exit' to end the chat session."); + Console.WriteLine("Type 'save' to save the chat session to disk."); + Console.WriteLine("Type 'load' to load the chat session from disk."); + Console.WriteLine("Type 'regenerate' to regenerate the last response."); // show the prompt Console.ForegroundColor = ConsoleColor.Green; @@ -55,12 +59,14 @@ public class ChatSessionWithHistory while (userInput != "exit") { + // Save the chat state to disk if (userInput == "save") { session.SaveSession("Assets/chat-with-bob"); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session saved."); } + // Load the chat state from disk else if (userInput == "load") { session.LoadSession("Assets/chat-with-bob"); diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs index 234bac3c..923f78f6 100644 --- a/LLama.Examples/Examples/ChatSessionWithRestart.cs +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -37,8 +37,11 @@ public class ChatSessionWithRestart }; Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("The chat session has started. Write `save` to save session in memory." - + " Write `reset` to start from the last saved checkpoint"); + Console.WriteLine("The chat session has started. Starting point saved."); + Console.WriteLine("Type 'exit' to end the chat session."); + Console.WriteLine("Type 'save' to save chat session state in memory."); + Console.WriteLine("Type 'reset' to reset the chat session to its saved state."); + Console.WriteLine("Type 'answer for assistant' to add and process provided user and assistant messages."); // show the prompt Console.ForegroundColor = ConsoleColor.Green; @@ -46,6 +49,7 @@ public class ChatSessionWithRestart while (userInput != "exit") { + // Load the session state from the reset state if(userInput == "reset") { session.LoadSession(resetState); @@ -53,25 +57,33 @@ public class ChatSessionWithRestart Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session reset."); } + // Assign new reset state. else if (userInput == "save") { resetState = session.GetSessionState(); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session saved."); } - else if (userInput == "regenerate") + // Provide user and override assistant answer with your own. + else if (userInput == "answer for assistant") { Console.ForegroundColor = ConsoleColor.Yellow; - Console.WriteLine("Regenerating last response ..."); + Console.WriteLine("Provide user input: "); - await foreach ( - var text - in session.RegenerateAssistantMessageAsync( - inferenceParams)) - { - Console.ForegroundColor = ConsoleColor.White; - Console.Write(text); - } + Console.ForegroundColor = ConsoleColor.Green; + string userInputOverride = Console.ReadLine() ?? ""; + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("Provide assistant input: "); + + Console.ForegroundColor = ConsoleColor.Green; + string assistantInputOverride = Console.ReadLine() ?? ""; + + await session.AddAndProcessUserMessage(userInputOverride); + await session.AddAndProcessAssistantMessage(assistantInputOverride); + + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine("User and assistant messages processed. Provide next user message:"); } else { diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 6c9accdf..9620dc4f 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -262,33 +262,6 @@ public class ChatSession return this; } - - /// - /// Compute KV cache for the system message and add it to the chat history. - /// - /// - /// - public async Task ProcessSystemMessage(string content) - { - if (Executor is not StatefulExecutorBase statefulExecutor) - { - throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages."); - } - if (History.Messages.Count > 0) - { - throw new ArgumentException("Cannot add a system message after another message", nameof(content)); - } - foreach (var inputTransform in InputTransformPipeline) - { - content = inputTransform.Transform(content); - } - - await statefulExecutor.PrefillPromptAsync(content); - - History.AddMessage(AuthorRole.System, content); - return this; - } - /// /// Add a system message to the chat history. /// @@ -323,6 +296,49 @@ public class ChatSession return this; } + /// + /// Compute KV cache for the message and add it to the chat history. + /// + /// + /// + public async Task AddAndProcessMessage(ChatHistory.Message message) + { + if (Executor is not StatefulExecutorBase statefulExecutor) + { + throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages."); + } + AddMessage(message); + var content = message.Content; + if (message.AuthorRole != AuthorRole.Assistant) + { + foreach (var inputTransform in InputTransformPipeline) + { + content = inputTransform.Transform(content); + } + } + + await statefulExecutor.PrefillPromptAsync(content); + return this; + } + + /// + /// Compute KV cache for the system message and add it to the chat history. + /// + public Task AddAndProcessSystemMessage(string content) + => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content)); + + /// + /// Compute KV cache for the user message and add it to the chat history. + /// + public Task AddAndProcessUserMessage(string content) + => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content)); + + /// + /// Compute KV cache for the assistant message and add it to the chat history. + /// + public Task AddAndProcessAssistantMessage(string content) + => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content)); + /// /// Replace a user message with a new message and remove all messages after the new message. /// This is useful when the user wants to edit a message. And regenerate the response.