diff --git a/LLama.WebAPI/Controllers/ChatController.cs b/LLama.WebAPI/Controllers/ChatController.cs index cf64a9e5..001a3224 100644 --- a/LLama.WebAPI/Controllers/ChatController.cs +++ b/LLama.WebAPI/Controllers/ChatController.cs @@ -1,3 +1,4 @@ +using LLama.Common; using LLama.WebAPI.Models; using LLama.WebAPI.Services; using Microsoft.AspNetCore.Mvc; @@ -9,20 +10,44 @@ namespace LLama.WebAPI.Controllers [Route("[controller]")] public class ChatController : ControllerBase { - private readonly ChatService _service; private readonly ILogger _logger; - public ChatController(ILogger logger, - ChatService service) + public ChatController(ILogger logger) { _logger = logger; - _service = service; } [HttpPost("Send")] - public string SendMessage([FromBody] SendMessageInput input) + public string SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service) { return _service.Send(input); } + + [HttpPost("Send/Stream")] + public async Task SendMessageStream([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service, CancellationToken cancellationToken) + { + + Response.ContentType = "text/event-stream"; + + await foreach (var r in _service.SendStream(input)) + { + await Response.WriteAsync("data:" + r + "\n\n", cancellationToken); + await Response.Body.FlushAsync(cancellationToken); + } + + await Response.CompleteAsync(); + } + + [HttpPost("History")] + public async Task SendHistory([FromBody] HistoryInput input, [FromServices] StatelessChatService _service) + { + var history = new ChatHistory(); + + var messages = input.Messages.Select(m => new ChatHistory.Message(Enum.Parse(m.Role), m.Content)); + + history.Messages.AddRange(messages); + + return await _service.SendAsync(history); + } } } \ No newline at end of file diff --git a/LLama.WebAPI/LLama.WebAPI.csproj b/LLama.WebAPI/LLama.WebAPI.csproj index 9227d57f..2b14d8a4 100644 --- a/LLama.WebAPI/LLama.WebAPI.csproj +++ b/LLama.WebAPI/LLama.WebAPI.csproj @@ -7,6 +7,7 @@ + diff --git a/LLama.WebAPI/Models/SendMessageInput.cs b/LLama.WebAPI/Models/SendMessageInput.cs index 741b631f..11152ff8 100644 --- a/LLama.WebAPI/Models/SendMessageInput.cs +++ b/LLama.WebAPI/Models/SendMessageInput.cs @@ -4,3 +4,13 @@ public class SendMessageInput { public string Text { get; set; } } + +public class HistoryInput +{ + public List Messages { get; set; } + public class HistoryItem + { + public string Role { get; set; } + public string Content { get; set; } + } +} \ No newline at end of file diff --git a/LLama.WebAPI/Program.cs b/LLama.WebAPI/Program.cs index 33e8bf81..3f2de200 100644 --- a/LLama.WebAPI/Program.cs +++ b/LLama.WebAPI/Program.cs @@ -9,7 +9,8 @@ builder.Services.AddControllers(); builder.Services.AddEndpointsApiExplorer(); builder.Services.AddSwaggerGen(); -builder.Services.AddSingleton(); +builder.Services.AddSingleton(); +builder.Services.AddScoped(); var app = builder.Build(); diff --git a/LLama.WebAPI/Services/ChatService.cs b/LLama.WebAPI/Services/ChatService.cs deleted file mode 100644 index e457e3c2..00000000 --- a/LLama.WebAPI/Services/ChatService.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.OldVersion; -using LLama.WebAPI.Models; - -namespace LLama.WebAPI.Services; - -public class ChatService -{ - private readonly ChatSession _session; - - public ChatService() - { - LLamaModel model = new(new LLamaParams(model: @"ggml-model-q4_0.bin", n_ctx: 512, interactive: true, repeat_penalty: 1.0f, verbose_prompt: false)); - _session = new ChatSession(model) - .WithPromptFile(@"Assets\chat-with-bob.txt") - .WithAntiprompt(new string[] { "User:" }); - } - - public string Send(SendMessageInput input) - { - Console.ForegroundColor = ConsoleColor.Green; - Console.WriteLine(input.Text); - - Console.ForegroundColor = ConsoleColor.White; - var outputs = _session.Chat(input.Text); - var result = ""; - foreach (var output in outputs) - { - Console.Write(output); - result += output; - } - - return result; - } -} diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs new file mode 100644 index 00000000..ab89b517 --- /dev/null +++ b/LLama.WebAPI/Services/StatefulChatService.cs @@ -0,0 +1,82 @@ + +using LLama.WebAPI.Models; +using Microsoft; +using System.Runtime.CompilerServices; + +namespace LLama.WebAPI.Services; + +public class StatefulChatService : IDisposable +{ + private readonly ChatSession _session; + private readonly LLamaModel _model; + private bool _continue = false; + + private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\n\n" + + "User: "; + + public StatefulChatService(IConfiguration configuration) + { + _model = new LLamaModel(new Common.ModelParams(configuration["ModelPath"], contextSize: 512)); + _session = new ChatSession(new InteractiveExecutor(_model)); + } + + public void Dispose() + { + _model?.Dispose(); + } + + public string Send(SendMessageInput input) + { + var userInput = input.Text; + if (!_continue) + { + userInput = SystemPrompt + userInput; + Console.Write(SystemPrompt); + _continue = true; + } + + Console.ForegroundColor = ConsoleColor.Green; + Console.Write(input.Text); + + Console.ForegroundColor = ConsoleColor.White; + var outputs = _session.Chat(userInput, new Common.InferenceParams() + { + RepeatPenalty = 1.0f, + AntiPrompts = new string[] { "User:" }, + }); + var result = ""; + foreach (var output in outputs) + { + Console.Write(output); + result += output; + } + + return result; + } + + public async IAsyncEnumerable SendStream(SendMessageInput input) + { + var userInput = input.Text; + if (!_continue) + { + userInput = SystemPrompt + userInput; + Console.Write(SystemPrompt); + _continue = true; + } + + Console.ForegroundColor = ConsoleColor.Green; + Console.Write(input.Text); + + Console.ForegroundColor = ConsoleColor.White; + var outputs = _session.ChatAsync(userInput, new Common.InferenceParams() + { + RepeatPenalty = 1.0f, + AntiPrompts = new string[] { "User:" }, + }); + await foreach (var output in outputs) + { + Console.Write(output); + yield return output; + } + } +} diff --git a/LLama.WebAPI/Services/StatelessChatService.cs b/LLama.WebAPI/Services/StatelessChatService.cs new file mode 100644 index 00000000..c1356646 --- /dev/null +++ b/LLama.WebAPI/Services/StatelessChatService.cs @@ -0,0 +1,48 @@ +using LLama.Common; +using Microsoft.AspNetCore.Http; +using System.Text; +using static LLama.LLamaTransforms; + +namespace LLama.WebAPI.Services +{ + public class StatelessChatService + { + private readonly LLamaModel _model; + private readonly ChatSession _session; + + public StatelessChatService(IConfiguration configuration) + { + _model = new LLamaModel(new ModelParams(configuration["ModelPath"], contextSize: 512)); + // TODO: replace with a stateless executor + _session = new ChatSession(new InteractiveExecutor(_model)) + .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Assistant:" }, redundancyLength: 8)) + .WithHistoryTransform(new HistoryTransform()); + } + + public async Task SendAsync(ChatHistory history) + { + var result = _session.ChatAsync(history, new InferenceParams() + { + AntiPrompts = new string[] { "User:" }, + }); + + var sb = new StringBuilder(); + await foreach (var r in result) + { + Console.Write(r); + sb.Append(r); + } + + return sb.ToString(); + + } + } + public class HistoryTransform : DefaultHistoryTransform + { + public override string HistoryToText(ChatHistory history) + { + return base.HistoryToText(history) + "\n Assistant:"; + } + + } +} diff --git a/LLama.WebAPI/appsettings.json b/LLama.WebAPI/appsettings.json index 10f68b8c..09cfdb13 100644 --- a/LLama.WebAPI/appsettings.json +++ b/LLama.WebAPI/appsettings.json @@ -5,5 +5,6 @@ "Microsoft.AspNetCore": "Warning" } }, - "AllowedHosts": "*" + "AllowedHosts": "*", + "ModelPath": "..\\..\\LLamaModel\\ggml-model-f32-q4_0.bin" }