From c9108f83117d8cd237464648f73ef2e1c094ef75 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Wed, 4 Oct 2023 10:40:53 +1300 Subject: [PATCH 1/7] Add service for managing Models and Model Contexts --- LLama.Web/Async/AsyncLock.cs | 55 ++++++++ LLama.Web/Common/LLamaOptions.cs | 1 + LLama.Web/Common/ModelLoadType.cs | 30 ++++ LLama.Web/Common/ModelOptions.cs | 206 +++++++++++++++------------- LLama.Web/LLamaModel.cs | 106 ++++++++++++++ LLama.Web/Services/IModelService.cs | 76 ++++++++++ LLama.Web/Services/ModelService.cs | 202 +++++++++++++++++++++++++++ 7 files changed, 582 insertions(+), 94 deletions(-) create mode 100644 LLama.Web/Async/AsyncLock.cs create mode 100644 LLama.Web/Common/ModelLoadType.cs create mode 100644 LLama.Web/LLamaModel.cs create mode 100644 LLama.Web/Services/IModelService.cs create mode 100644 LLama.Web/Services/ModelService.cs diff --git a/LLama.Web/Async/AsyncLock.cs b/LLama.Web/Async/AsyncLock.cs new file mode 100644 index 00000000..09ccb0f7 --- /dev/null +++ b/LLama.Web/Async/AsyncLock.cs @@ -0,0 +1,55 @@ +namespace LLama.Web.Async +{ + /// + /// Creates an async lock for use with a using statement + /// + public sealed class AsyncLock + { + private readonly SemaphoreSlim _semaphore; + private readonly Task _releaser; + + + /// + /// Initializes a new instance of the AsyncLock class. + /// + public AsyncLock() + { + _semaphore = new SemaphoreSlim(1, 1); + _releaser = Task.FromResult((IDisposable)new Releaser(this)); + } + + + /// + /// Locks the using statement asynchronously.
+ /// + /// + public Task LockAsync() + { + var wait = _semaphore.WaitAsync(); + if (wait.IsCompleted) + return _releaser; + + return wait.ContinueWith((_, state) => (IDisposable)state, _releaser.Result, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + } + + + /// + /// IDisposable wrapper class to release the lock on dispose + /// + /// + private sealed class Releaser : IDisposable + { + private readonly AsyncLock _lockToRelease; + + internal Releaser(AsyncLock lockToRelease) + { + _lockToRelease = lockToRelease; + } + + public void Dispose() + { + _lockToRelease._semaphore.Release(); + } + } + } +} diff --git a/LLama.Web/Common/LLamaOptions.cs b/LLama.Web/Common/LLamaOptions.cs index 1ac0d829..a64b9635 100644 --- a/LLama.Web/Common/LLamaOptions.cs +++ b/LLama.Web/Common/LLamaOptions.cs @@ -2,6 +2,7 @@ { public class LLamaOptions { + public ModelLoadType ModelLoadType { get; set; } public List Models { get; set; } public List Prompts { get; set; } = new List(); public List Parameters { get; set; } = new List(); diff --git a/LLama.Web/Common/ModelLoadType.cs b/LLama.Web/Common/ModelLoadType.cs new file mode 100644 index 00000000..9e1c77b7 --- /dev/null +++ b/LLama.Web/Common/ModelLoadType.cs @@ -0,0 +1,30 @@ +namespace LLama.Web.Common +{ + /// + /// The type of model load caching to use + /// + public enum ModelLoadType + { + + /// + /// Only one model will be loaded into memory at a time, any other models will be unloaded before the new one is loaded + /// + Single = 0, + + /// + /// Multiple models will be loaded into memory, ensure you use the ModelConfigs to split the hardware resources + /// + Multiple = 1, + + /// + /// The first model in the appsettings.json list will be preloaded into memory at app startup + /// + PreloadSingle = 2, + + + /// + /// All models in the appsettings.json list will be preloaded into memory at app startup, ensure you use the ModelConfigs to split the hardware resources + /// + 
PreloadMultiple = 3, + } +} diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index f06757e3..c6cf0988 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -3,105 +3,123 @@ using LLama.Abstractions; namespace LLama.Web.Common { - public class ModelOptions - : IModelParams + public class ModelOptions : IModelParams { - + /// + /// Model friendly name + /// public string Name { get; set; } + + /// + /// Max context instances allowed per model + /// public int MaxInstances { get; set; } + /// + /// Model context size (n_ctx) + /// + public int ContextSize { get; set; } = 512; + + /// + /// the GPU that is used for scratch and small tensors + /// + public int MainGpu { get; set; } = 0; + + /// + /// if true, reduce VRAM usage at the cost of performance + /// + public bool LowVram { get; set; } = false; + + /// + /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) + /// + public int GpuLayerCount { get; set; } = 20; + + /// + /// Seed for the random number generator (seed) + /// + public int Seed { get; set; } = 1686349486; + + /// + /// Use f16 instead of f32 for memory kv (memory_f16) + /// + public bool UseFp16Memory { get; set; } = true; + + /// + /// Use mmap for faster loads (use_mmap) + /// + public bool UseMemorymap { get; set; } = true; + + /// + /// Use mlock to keep model in memory (use_mlock) + /// + public bool UseMemoryLock { get; set; } = false; + + /// + /// Compute perplexity over the prompt (perplexity) + /// + public bool Perplexity { get; set; } = false; + + /// + /// Model path (model) + /// + public string ModelPath { get; set; } + + /// + /// model alias + /// + public string ModelAlias { get; set; } = "unknown"; - /// - /// Model context size (n_ctx) - /// - public int ContextSize { get; set; } = 512; - /// - /// the GPU that is used for scratch and small tensors - /// - public int MainGpu { get; set; } = 0; - /// - /// if true, reduce VRAM usage at the cost of
performance - /// - public bool LowVram { get; set; } = false; - /// - /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) - /// - public int GpuLayerCount { get; set; } = 20; - /// - /// Seed for the random number generator (seed) - /// - public int Seed { get; set; } = 1686349486; - /// - /// Use f16 instead of f32 for memory kv (memory_f16) - /// - public bool UseFp16Memory { get; set; } = true; - /// - /// Use mmap for faster loads (use_mmap) - /// - public bool UseMemorymap { get; set; } = true; - /// - /// Use mlock to keep model in memory (use_mlock) - /// - public bool UseMemoryLock { get; set; } = false; - /// - /// Compute perplexity over the prompt (perplexity) - /// - public bool Perplexity { get; set; } = false; - /// - /// Model path (model) - /// - public string ModelPath { get; set; } - /// - /// model alias - /// - public string ModelAlias { get; set; } = "unknown"; - /// - /// lora adapter path (lora_adapter) - /// - public string LoraAdapter { get; set; } = string.Empty; - /// - /// base model path for the lora adapter (lora_base) - /// - public string LoraBase { get; set; } = string.Empty; - /// - /// Number of threads (-1 = autodetect) (n_threads) - /// - public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1); - /// - /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch) - /// - public int BatchSize { get; set; } = 512; - - /// - /// Whether to convert eos to newline during the inference. - /// - public bool ConvertEosToNewLine { get; set; } = false; - - /// - /// Whether to use embedding mode. (embedding) Note that if this is set to true, - /// The LLamaModel won't produce text response anymore. 
- /// - public bool EmbeddingMode { get; set; } = false; - - /// - /// how split tensors should be distributed across GPUs - /// - public float[] TensorSplits { get; set; } - - /// - /// RoPE base frequency - /// - public float RopeFrequencyBase { get; set; } = 10000.0f; - - /// - /// RoPE frequency scaling factor - /// - public float RopeFrequencyScale { get; set; } = 1.0f; - - /// - /// Use experimental mul_mat_q kernels - /// - public bool MulMatQ { get; set; } + /// + /// lora adapter path (lora_adapter) + /// + public string LoraAdapter { get; set; } = string.Empty; + + /// + /// base model path for the lora adapter (lora_base) + /// + public string LoraBase { get; set; } = string.Empty; + + /// + /// Number of threads (-1 = autodetect) (n_threads) + /// + public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1); + + /// + /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch) + /// + public int BatchSize { get; set; } = 512; + + /// + /// Whether to convert eos to newline during the inference. + /// + public bool ConvertEosToNewLine { get; set; } = false; + + /// + /// Whether to use embedding mode. (embedding) Note that if this is set to true, + /// The LLamaModel won't produce text response anymore. 
+ /// + public bool EmbeddingMode { get; set; } = false; + + /// + /// how split tensors should be distributed across GPUs + /// + public float[] TensorSplits { get; set; } + + /// + /// RoPE base frequency + /// + public float RopeFrequencyBase { get; set; } = 10000.0f; + + /// + /// RoPE frequency scaling factor + /// + public float RopeFrequencyScale { get; set; } = 1.0f; + + /// + /// Use experimental mul_mat_q kernels + /// + public bool MulMatQ { get; set; } /// /// The encoding to use for models diff --git a/LLama.Web/LLamaModel.cs b/LLama.Web/LLamaModel.cs new file mode 100644 index 00000000..e500ba04 --- /dev/null +++ b/LLama.Web/LLamaModel.cs @@ -0,0 +1,106 @@ +using LLama.Abstractions; +using LLama.Web.Common; +using System.Collections.Concurrent; + +namespace LLama.Web +{ + /// + /// Wrapper class for LLamaSharp LLamaWeights + /// + /// + public class LLamaModel : IDisposable + { + private readonly ModelOptions _config; + private readonly LLamaWeights _weights; + private readonly ConcurrentDictionary _contexts; + + /// + /// Initializes a new instance of the class. + /// + /// The model parameters. + public LLamaModel(ModelOptions modelParams) + { + _config = modelParams; + _weights = LLamaWeights.LoadFromFile(modelParams); + _contexts = new ConcurrentDictionary(); + } + + /// + /// Gets the model configuration. + /// + public IModelParams ModelParams => _config; + + /// + /// Gets the LLamaWeights + /// + public LLamaWeights LLamaWeights => _weights; + + + /// + /// Gets the context count. 
+ /// + public int ContextCount => _contexts.Count; + + + /// + /// Creates a new context session on this model + /// + /// The unique context identifier + /// LLamaModelContext for this LLamaModel + /// Context exists + public Task CreateContext(string contextName) + { + if (_contexts.TryGetValue(contextName, out var context)) + throw new Exception($"Context with id {contextName} already exists."); + + if (_config.MaxInstances > -1 && ContextCount >= _config.MaxInstances) + throw new Exception($"Maximum model instances reached"); + + context = _weights.CreateContext(_config); + if (_contexts.TryAdd(contextName, context)) + return Task.FromResult(context); + + return Task.FromResult(null); + } + + /// + /// Get a contexts belonging to this model + /// + /// The unique context identifier + /// LLamaModelContext for this LLamaModel with the specified contextName + public Task GetContext(string contextName) + { + if (_contexts.TryGetValue(contextName, out var context)) + return Task.FromResult(context); + + return Task.FromResult(null); + } + + /// + /// Remove a context from this model + /// + /// The unique context identifier + /// true if removed, otherwise false + public Task RemoveContext(string contextName) + { + if (!_contexts.TryRemove(contextName, out var context)) + return Task.FromResult(false); + + context?.Dispose(); + return Task.FromResult(true); + } + + + /// + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. 
+ /// + public void Dispose() + { + foreach (var context in _contexts.Values) + { + context?.Dispose(); + } + _weights.Dispose(); + } + } +} diff --git a/LLama.Web/Services/IModelService.cs b/LLama.Web/Services/IModelService.cs new file mode 100644 index 00000000..0a98f8f4 --- /dev/null +++ b/LLama.Web/Services/IModelService.cs @@ -0,0 +1,76 @@ +using LLama.Web.Common; + +namespace LLama.Web.Services +{ + /// + /// Service for managing language Models + /// + public interface IModelService + { + /// + /// Gets the model with the specified name. + /// + /// Name of the model. + Task GetModel(string modelName); + + + /// + /// Loads a model from a ModelConfig object. + /// + /// The model configuration. + Task LoadModel(ModelOptions modelOptions); + + + /// + /// Loads all models found in appsettings.json + /// + Task LoadModels(); + + + /// + /// Unloads the model with the specified name. + /// + /// Name of the model. + Task UnloadModel(string modelName); + + + /// + /// Unloads all models. + /// + Task UnloadModels(); + + + /// + /// Gets a context with the specified identifier + /// + /// Name of the model. + /// The context identifier. + Task GetContext(string modelName, string contextName); + + + /// + /// Removes the context. + /// + /// Name of the model. + /// The context identifier. + Task RemoveContext(string modelName, string contextName); + + + /// + /// Creates a context. + /// + /// Name of the model. + /// The context identifier. + Task CreateContext(string modelName, string contextName); + + + /// + /// Gets the or create model and context. + /// This will load a model from disk if not already loaded, and also create the context + /// + /// Name of the model. + /// The context identifier. 
+ /// Both loaded Model and Context + Task<(LLamaModel, LLamaContext)> GetOrCreateModelAndContext(string modelName, string contextName); + } +} \ No newline at end of file diff --git a/LLama.Web/Services/ModelService.cs b/LLama.Web/Services/ModelService.cs new file mode 100644 index 00000000..16365a5d --- /dev/null +++ b/LLama.Web/Services/ModelService.cs @@ -0,0 +1,202 @@ +using LLama.Web.Async; +using LLama.Web.Common; +using System.Collections.Concurrent; + +namespace LLama.Web.Services +{ + + /// + /// Service for handling Models, Weights & Contexts + /// + public class ModelService : IModelService + { + private readonly AsyncLock _modelLock; + private readonly AsyncLock _contextLock; + private readonly LLamaOptions _configuration; + private readonly ConcurrentDictionary _modelInstances; + + + /// + /// Initializes a new instance of the ModelService class. + /// + /// The LLama configuration (model load type and model list) + /// that this service reads when loading and unloading models. + public ModelService(LLamaOptions configuration) + { + _modelLock = new AsyncLock(); + _contextLock = new AsyncLock(); + _configuration = configuration; + _modelInstances = new ConcurrentDictionary(); + } + + + /// + /// Loads a model with the provided configuration. + /// + /// The model configuration. + /// + public async Task LoadModel(ModelOptions modelOptions) + { + if (_modelInstances.TryGetValue(modelOptions.Name, out var existingModel)) + return existingModel; + + using (await _modelLock.LockAsync()) + { + if (_modelInstances.TryGetValue(modelOptions.Name, out var model)) + return model; + + // If in single mode unload any other models + if (_configuration.ModelLoadType == ModelLoadType.Single + || _configuration.ModelLoadType == ModelLoadType.PreloadSingle) + await UnloadModels(); + + + model = new LLamaModel(modelOptions); + _modelInstances.TryAdd(modelOptions.Name, model); + return model; + } + } + + + /// + /// Loads the models.
+ /// + public async Task LoadModels() + { + if (_configuration.ModelLoadType == ModelLoadType.Single + || _configuration.ModelLoadType == ModelLoadType.Multiple) + return; + + foreach (var modelConfig in _configuration.Models) + { + await LoadModel(modelConfig); + + // Only preload the first model when in PreloadSingle mode + if (_configuration.ModelLoadType == ModelLoadType.PreloadSingle) + break; + } + } + + + /// + /// Unloads the model. + /// + /// Name of the model. + /// + public Task UnloadModel(string modelName) + { + if (_modelInstances.TryRemove(modelName, out var model)) + { + model?.Dispose(); + return Task.FromResult(true); + } + return Task.FromResult(false); + } + + + + /// + /// Unloads all models. + /// + public async Task UnloadModels() + { + foreach (var modelName in _modelInstances.Keys) + { + await UnloadModel(modelName); + } + } + + + /// + /// Gets a model by name. + /// + /// Name of the model. + /// + public Task GetModel(string modelName) + { + _modelInstances.TryGetValue(modelName, out var model); + return Task.FromResult(model); + } + + + /// + /// Gets a context from the specified model. + /// + /// Name of the model. + /// The contextName. + /// + /// Model not found + public async Task GetContext(string modelName, string contextName) + { + if (!_modelInstances.TryGetValue(modelName, out var model)) + throw new Exception("Model not found"); + + return await model.GetContext(contextName); + } + + + /// + /// Creates a context on the specified model. + /// + /// Name of the model. + /// The contextName. + /// + /// Model not found + public async Task CreateContext(string modelName, string contextName) + { + if (!_modelInstances.TryGetValue(modelName, out var model)) + throw new Exception("Model not found"); + + using (await _contextLock.LockAsync()) + { + return await model.CreateContext(contextName); + } + } + + + /// + /// Removes a context from the specified model. + /// + /// Name of the model. + /// The contextName.
+ /// + /// Model not found + public async Task RemoveContext(string modelName, string contextName) + { + if (!_modelInstances.TryGetValue(modelName, out var model)) + throw new Exception("Model not found"); + + using (await _contextLock.LockAsync()) + { + return await model.RemoveContext(contextName); + } + } + + + /// + /// Loads, Gets,Creates a Model and a Context + /// + /// Name of the model. + /// The contextName. + /// + /// Model option '{modelName}' not found + public async Task<(LLamaModel, LLamaContext)> GetOrCreateModelAndContext(string modelName, string contextName) + { + if (_modelInstances.TryGetValue(modelName, out var model)) + return (model, await model.GetContext(contextName) ?? await model.CreateContext(contextName)); + + + // Get model configuration + var modelConfig = _configuration.Models.FirstOrDefault(x => x.Name == modelName); + if (modelConfig is null) + throw new Exception($"Model option '{modelName}' not found"); + + // Load Model + model = await LoadModel(modelConfig); + + // Get or Create Context + return (model, await model.GetContext(contextName) ?? 
await model.CreateContext(contextName)); + } + + } +} From 44f1b91c292eba68df285005e2e763484a45fc85 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Wed, 4 Oct 2023 12:57:15 +1300 Subject: [PATCH 2/7] Update Web to support version 0.5.1 --- LLama.Web/Async/AsyncGuard.cs | 107 +++++++++ LLama.Web/Common/InferenceOptions.cs | 101 ++++++++ LLama.Web/Common/LLamaOptions.cs | 9 - LLama.Web/Common/ParameterOptions.cs | 105 --------- LLama.Web/Common/PromptOptions.cs | 11 - LLama.Web/Common/SessionOptions.cs | 14 ++ LLama.Web/Extensioms.cs | 54 +++++ LLama.Web/Hubs/ISessionClient.cs | 1 - LLama.Web/Hubs/SessionConnectionHub.cs | 57 ++--- LLama.Web/LLama.Web.csproj | 4 + LLama.Web/{ => Models}/LLamaModel.cs | 4 +- LLama.Web/Models/ModelSession.cs | 138 ++++++++--- LLama.Web/Models/ResponseFragment.cs | 18 -- LLama.Web/Models/TokenModel.cs | 24 ++ LLama.Web/Pages/Executor/Instruct.cshtml | 96 -------- LLama.Web/Pages/Executor/Instruct.cshtml.cs | 34 --- LLama.Web/Pages/Executor/Instruct.cshtml.css | 4 - LLama.Web/Pages/Executor/Interactive.cshtml | 96 -------- .../Pages/Executor/Interactive.cshtml.cs | 34 --- .../Pages/Executor/Interactive.cshtml.css | 4 - LLama.Web/Pages/Executor/Stateless.cshtml | 97 -------- LLama.Web/Pages/Executor/Stateless.cshtml.cs | 34 --- LLama.Web/Pages/Executor/Stateless.cshtml.css | 4 - LLama.Web/Pages/Index.cshtml | 119 +++++++++- LLama.Web/Pages/Index.cshtml.cs | 25 +- LLama.Web/Pages/Shared/_ChatTemplates.cshtml | 24 +- LLama.Web/Pages/Shared/_Layout.cshtml | 32 +-- LLama.Web/Pages/Shared/_Parameters.cshtml | 137 +++++++++++ LLama.Web/Program.cs | 5 +- .../Services/ConnectionSessionService.cs | 94 -------- LLama.Web/Services/IModelService.cs | 1 + LLama.Web/Services/IModelSessionService.cs | 84 ++++++- LLama.Web/Services/ModelLoaderService.cs | 42 ++++ LLama.Web/Services/ModelService.cs | 1 + LLama.Web/Services/ModelSessionService.cs | 216 ++++++++++++++++++ LLama.Web/appsettings.json | 60 ++--- LLama.Web/wwwroot/css/site.css | 25 +- 
LLama.Web/wwwroot/js/sessionConnectionChat.js | 139 +++++++---- LLama.Web/wwwroot/js/site.js | 8 +- 39 files changed, 1208 insertions(+), 854 deletions(-) create mode 100644 LLama.Web/Async/AsyncGuard.cs create mode 100644 LLama.Web/Common/InferenceOptions.cs delete mode 100644 LLama.Web/Common/ParameterOptions.cs delete mode 100644 LLama.Web/Common/PromptOptions.cs create mode 100644 LLama.Web/Common/SessionOptions.cs create mode 100644 LLama.Web/Extensioms.cs rename LLama.Web/{ => Models}/LLamaModel.cs (98%) delete mode 100644 LLama.Web/Models/ResponseFragment.cs create mode 100644 LLama.Web/Models/TokenModel.cs delete mode 100644 LLama.Web/Pages/Executor/Instruct.cshtml delete mode 100644 LLama.Web/Pages/Executor/Instruct.cshtml.cs delete mode 100644 LLama.Web/Pages/Executor/Instruct.cshtml.css delete mode 100644 LLama.Web/Pages/Executor/Interactive.cshtml delete mode 100644 LLama.Web/Pages/Executor/Interactive.cshtml.cs delete mode 100644 LLama.Web/Pages/Executor/Interactive.cshtml.css delete mode 100644 LLama.Web/Pages/Executor/Stateless.cshtml delete mode 100644 LLama.Web/Pages/Executor/Stateless.cshtml.cs delete mode 100644 LLama.Web/Pages/Executor/Stateless.cshtml.css create mode 100644 LLama.Web/Pages/Shared/_Parameters.cshtml delete mode 100644 LLama.Web/Services/ConnectionSessionService.cs create mode 100644 LLama.Web/Services/ModelLoaderService.cs create mode 100644 LLama.Web/Services/ModelSessionService.cs diff --git a/LLama.Web/Async/AsyncGuard.cs b/LLama.Web/Async/AsyncGuard.cs new file mode 100644 index 00000000..ff6b6c43 --- /dev/null +++ b/LLama.Web/Async/AsyncGuard.cs @@ -0,0 +1,107 @@ +using System.Collections.Concurrent; + +namespace LLama.Web.Async +{ + + /// + /// Creates a async/thread-safe guard helper + /// + /// + public class AsyncGuard : AsyncGuard + { + private readonly byte _key; + private readonly ConcurrentDictionary _lockData; + + + /// + /// Initializes a new instance of the class. 
+ /// + public AsyncGuard() + { + _key = 0; + _lockData = new ConcurrentDictionary(); + } + + + /// + /// Guards this instance. + /// + /// true if able to enter an guard, false if already guarded + public bool Guard() + { + return _lockData.TryAdd(_key, true); + } + + + /// + /// Releases the guard. + /// + /// + public bool Release() + { + return _lockData.TryRemove(_key, out _); + } + + + /// + /// Determines whether this instance is guarded. + /// + /// + /// true if this instance is guarded; otherwise, false. + /// + public bool IsGuarded() + { + return _lockData.ContainsKey(_key); + } + } + + + public class AsyncGuard + { + private readonly ConcurrentDictionary _lockData; + + + /// + /// Initializes a new instance of the class. + /// + public AsyncGuard() + { + _lockData = new ConcurrentDictionary(); + } + + + /// + /// Guards the specified value. + /// + /// The value. + /// true if able to enter a guard for this value, false if this value is already guarded + public bool Guard(T value) + { + return _lockData.TryAdd(value, true); + } + + + /// + /// Releases the guard on the specified value. + /// + /// The value. + /// + public bool Release(T value) + { + return _lockData.TryRemove(value, out _); + } + + + /// + /// Determines whether the specified value is guarded. + /// + /// The value. + /// + /// true if the specified value is guarded; otherwise, false. 
+ /// + public bool IsGuarded(T value) + { + return _lockData.ContainsKey(value); + } + } +} diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs new file mode 100644 index 00000000..c2420af3 --- /dev/null +++ b/LLama.Web/Common/InferenceOptions.cs @@ -0,0 +1,101 @@ +using LLama.Common; +using LLama.Abstractions; +using LLama.Native; + +namespace LLama.Web.Common +{ + public class InferenceOptions : IInferenceParams + { + /// + /// number of tokens to keep from initial prompt + /// + public int TokensKeep { get; set; } = 0; + /// + /// how many new tokens to predict (n_predict), set to -1 to inifinitely generate response + /// until it complete. + /// + public int MaxTokens { get; set; } = -1; + /// + /// logit bias for specific tokens + /// + public Dictionary? LogitBias { get; set; } = null; + + /// + /// Sequences where the model will stop generating further tokens. + /// + public IEnumerable AntiPrompts { get; set; } = Array.Empty(); + /// + /// path to file for saving/loading model eval state + /// + public string PathSession { get; set; } = string.Empty; + /// + /// string to suffix user inputs with + /// + public string InputSuffix { get; set; } = string.Empty; + /// + /// string to prefix user inputs with + /// + public string InputPrefix { get; set; } = string.Empty; + /// + /// 0 or lower to use vocab size + /// + public int TopK { get; set; } = 40; + /// + /// 1.0 = disabled + /// + public float TopP { get; set; } = 0.95f; + /// + /// 1.0 = disabled + /// + public float TfsZ { get; set; } = 1.0f; + /// + /// 1.0 = disabled + /// + public float TypicalP { get; set; } = 1.0f; + /// + /// 1.0 = disabled + /// + public float Temperature { get; set; } = 0.8f; + /// + /// 1.0 = disabled + /// + public float RepeatPenalty { get; set; } = 1.1f; + /// + /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n) + /// + public int RepeatLastTokensCount { get; set; } = 64; + /// + /// frequency penalty 
coefficient + /// 0.0 = disabled + /// + public float FrequencyPenalty { get; set; } = .0f; + /// + /// presence penalty coefficient + /// 0.0 = disabled + /// + public float PresencePenalty { get; set; } = .0f; + /// + /// Mirostat uses tokens instead of words. + /// algorithm described in the paper https://arxiv.org/abs/2007.14966. + /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + /// + public MirostatType Mirostat { get; set; } = MirostatType.Disable; + /// + /// target entropy + /// + public float MirostatTau { get; set; } = 5.0f; + /// + /// learning rate + /// + public float MirostatEta { get; set; } = 0.1f; + /// + /// consider newlines as a repeatable token (penalize_nl) + /// + public bool PenalizeNL { get; set; } = true; + + /// + /// A grammar to constrain possible tokens + /// + public SafeLLamaGrammarHandle Grammar { get; set; } = null; + } +} diff --git a/LLama.Web/Common/LLamaOptions.cs b/LLama.Web/Common/LLamaOptions.cs index a64b9635..4a1d6e0a 100644 --- a/LLama.Web/Common/LLamaOptions.cs +++ b/LLama.Web/Common/LLamaOptions.cs @@ -4,18 +4,9 @@ { public ModelLoadType ModelLoadType { get; set; } public List Models { get; set; } - public List Prompts { get; set; } = new List(); - public List Parameters { get; set; } = new List(); public void Initialize() { - foreach (var prompt in Prompts) - { - if (File.Exists(prompt.Path)) - { - prompt.Prompt = File.ReadAllText(prompt.Path).Trim(); - } - } } } } diff --git a/LLama.Web/Common/ParameterOptions.cs b/LLama.Web/Common/ParameterOptions.cs deleted file mode 100644 index f78aa861..00000000 --- a/LLama.Web/Common/ParameterOptions.cs +++ /dev/null @@ -1,105 +0,0 @@ -using LLama.Common; -using LLama.Abstractions; -using LLama.Native; - -namespace LLama.Web.Common -{ - public class ParameterOptions : IInferenceParams - { - public string Name { get; set; } - - - - /// - /// number of tokens to keep from initial prompt - /// - public int TokensKeep { get; set; } = 0; - /// - /// how many new tokens to predict 
(n_predict), set to -1 to inifinitely generate response - /// until it complete. - /// - public int MaxTokens { get; set; } = -1; - /// - /// logit bias for specific tokens - /// - public Dictionary? LogitBias { get; set; } = null; - - /// - /// Sequences where the model will stop generating further tokens. - /// - public IEnumerable AntiPrompts { get; set; } = Array.Empty(); - /// - /// path to file for saving/loading model eval state - /// - public string PathSession { get; set; } = string.Empty; - /// - /// string to suffix user inputs with - /// - public string InputSuffix { get; set; } = string.Empty; - /// - /// string to prefix user inputs with - /// - public string InputPrefix { get; set; } = string.Empty; - /// - /// 0 or lower to use vocab size - /// - public int TopK { get; set; } = 40; - /// - /// 1.0 = disabled - /// - public float TopP { get; set; } = 0.95f; - /// - /// 1.0 = disabled - /// - public float TfsZ { get; set; } = 1.0f; - /// - /// 1.0 = disabled - /// - public float TypicalP { get; set; } = 1.0f; - /// - /// 1.0 = disabled - /// - public float Temperature { get; set; } = 0.8f; - /// - /// 1.0 = disabled - /// - public float RepeatPenalty { get; set; } = 1.1f; - /// - /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n) - /// - public int RepeatLastTokensCount { get; set; } = 64; - /// - /// frequency penalty coefficient - /// 0.0 = disabled - /// - public float FrequencyPenalty { get; set; } = .0f; - /// - /// presence penalty coefficient - /// 0.0 = disabled - /// - public float PresencePenalty { get; set; } = .0f; - /// - /// Mirostat uses tokens instead of words. - /// algorithm described in the paper https://arxiv.org/abs/2007.14966. 
- /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - /// - public MirostatType Mirostat { get; set; } = MirostatType.Disable; - /// - /// target entropy - /// - public float MirostatTau { get; set; } = 5.0f; - /// - /// learning rate - /// - public float MirostatEta { get; set; } = 0.1f; - /// - /// consider newlines as a repeatable token (penalize_nl) - /// - public bool PenalizeNL { get; set; } = true; - - /// - /// A grammar to constrain possible tokens - /// - public SafeLLamaGrammarHandle Grammar { get; set; } = null; - } -} diff --git a/LLama.Web/Common/PromptOptions.cs b/LLama.Web/Common/PromptOptions.cs deleted file mode 100644 index 4e44a5d1..00000000 --- a/LLama.Web/Common/PromptOptions.cs +++ /dev/null @@ -1,11 +0,0 @@ -namespace LLama.Web.Common -{ - public class PromptOptions - { - public string Name { get; set; } - public string Path { get; set; } - public string Prompt { get; set; } - public List AntiPrompt { get; set; } - public List OutputFilter { get; set; } - } -} diff --git a/LLama.Web/Common/SessionOptions.cs b/LLama.Web/Common/SessionOptions.cs new file mode 100644 index 00000000..34386955 --- /dev/null +++ b/LLama.Web/Common/SessionOptions.cs @@ -0,0 +1,14 @@ +namespace LLama.Web.Common +{ + public class SessionOptions + { + public string Model { get; set; } + public string Prompt { get; set; } + + public string AntiPrompt { get; set; } + public List AntiPrompts { get; set; } + public string OutputFilter { get; set; } + public List OutputFilters { get; set; } + public LLamaExecutorType ExecutorType { get; set; } + } +} diff --git a/LLama.Web/Extensioms.cs b/LLama.Web/Extensioms.cs new file mode 100644 index 00000000..50bb55c4 --- /dev/null +++ b/LLama.Web/Extensioms.cs @@ -0,0 +1,54 @@ +using LLama.Web.Common; + +namespace LLama.Web +{ + public static class Extensioms + { + /// + /// Combines the AntiPrompts list and AntiPrompt csv + /// + /// The session configuration. 
+ /// Combined AntiPrompts with duplicates removed + public static List GetAntiPrompts(this Common.SessionOptions sessionConfig) + { + return CombineCSV(sessionConfig.AntiPrompts, sessionConfig.AntiPrompt); + } + + /// + /// Combines the OutputFilters list and OutputFilter csv + /// + /// The session configuration. + /// Combined OutputFilters with duplicates removed + public static List GetOutputFilters(this Common.SessionOptions sessionConfig) + { + return CombineCSV(sessionConfig.OutputFilters, sessionConfig.OutputFilter); + } + + + /// + /// Combines a string list and a csv and removes duplicates + /// + /// The list; a null or empty list contributes nothing. + /// The CSV. + /// Combined list with duplicates removed + private static List CombineCSV(List list, string csv) + { + var results = list is null || list.Count == 0 + ? CommaSeperatedToList(csv) + : CommaSeperatedToList(csv).Concat(list); + return results + .Distinct() + .ToList(); + } + + private static List CommaSeperatedToList(string value) + { + if (string.IsNullOrEmpty(value)) + return new List(); + + return value.Split(",", StringSplitOptions.RemoveEmptyEntries) + .Select(x => x.Trim()) + .ToList(); + } + } +} diff --git a/LLama.Web/Hubs/ISessionClient.cs b/LLama.Web/Hubs/ISessionClient.cs index 9e9dc0f1..92302b21 100644 --- a/LLama.Web/Hubs/ISessionClient.cs +++ b/LLama.Web/Hubs/ISessionClient.cs @@ -6,7 +6,6 @@ namespace LLama.Web.Hubs public interface ISessionClient { Task OnStatus(string connectionId, SessionConnectionStatus status); - Task OnResponse(ResponseFragment fragment); Task OnError(string error); } } diff --git a/LLama.Web/Hubs/SessionConnectionHub.cs b/LLama.Web/Hubs/SessionConnectionHub.cs index 080866c6..730d4e87 100644 --- a/LLama.Web/Hubs/SessionConnectionHub.cs +++ b/LLama.Web/Hubs/SessionConnectionHub.cs @@ -2,16 +2,15 @@ using LLama.Web.Models; using LLama.Web.Services; using Microsoft.AspNetCore.SignalR; -using System.Diagnostics; namespace LLama.Web.Hubs { public class SessionConnectionHub : Hub { private readonly ILogger _logger;
- private readonly ConnectionSessionService _modelSessionService; + private readonly IModelSessionService _modelSessionService; - public SessionConnectionHub(ILogger logger, ConnectionSessionService modelSessionService) + public SessionConnectionHub(ILogger logger, IModelSessionService modelSessionService) { _logger = logger; _modelSessionService = modelSessionService; @@ -27,29 +26,27 @@ namespace LLama.Web.Hubs } - public override async Task OnDisconnectedAsync(Exception? exception) + public override async Task OnDisconnectedAsync(Exception exception) { _logger.Log(LogLevel.Information, "[OnDisconnectedAsync], Id: {0}", Context.ConnectionId); // Remove connections session on dissconnect - await _modelSessionService.RemoveAsync(Context.ConnectionId); + await _modelSessionService.CloseAsync(Context.ConnectionId); await base.OnDisconnectedAsync(exception); } [HubMethodName("LoadModel")] - public async Task OnLoadModel(LLamaExecutorType executorType, string modelName, string promptName, string parameterName) + public async Task OnLoadModel(Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig) { - _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}, Model: {1}, Prompt: {2}, Parameter: {3}", Context.ConnectionId, modelName, promptName, parameterName); - - // Remove existing connections session - await _modelSessionService.RemoveAsync(Context.ConnectionId); + _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId); + await _modelSessionService.CloseAsync(Context.ConnectionId); // Create model session - var modelSessionResult = await _modelSessionService.CreateAsync(executorType, Context.ConnectionId, modelName, promptName, parameterName); - if (modelSessionResult.HasError) + var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, sessionConfig, inferenceConfig); + if (modelSession is null) { - await Clients.Caller.OnError(modelSessionResult.Error); 
+ await Clients.Caller.OnError("Failed to create model session"); return; } @@ -59,40 +56,12 @@ namespace LLama.Web.Hubs [HubMethodName("SendPrompt")] - public async Task OnSendPrompt(string prompt) + public IAsyncEnumerable OnSendPrompt(string prompt, InferenceOptions inferConfig, CancellationToken cancellationToken) { _logger.Log(LogLevel.Information, "[OnSendPrompt] - New prompt received, Connection: {0}", Context.ConnectionId); - // Get connections session - var modelSession = await _modelSessionService.GetAsync(Context.ConnectionId); - if (modelSession is null) - { - await Clients.Caller.OnError("No model has been loaded"); - return; - } - - - // Create unique response id - var responseId = Guid.NewGuid().ToString(); - - // Send begin of response - await Clients.Caller.OnResponse(new ResponseFragment(responseId, isFirst: true)); - - // Send content of response - var stopwatch = Stopwatch.GetTimestamp(); - await foreach (var fragment in modelSession.InferAsync(prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted))) - { - await Clients.Caller.OnResponse(new ResponseFragment(responseId, fragment)); - } - - // Send end of response - var elapsedTime = Stopwatch.GetElapsedTime(stopwatch); - var signature = modelSession.IsInferCanceled() - ? 
$"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds" - : $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds"; - await Clients.Caller.OnResponse(new ResponseFragment(responseId, signature, isLast: true)); - _logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, modelSession.IsInferCanceled()); + var linkedCancelationToken = CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted, cancellationToken); + return _modelSessionService.InferAsync(Context.ConnectionId, prompt, inferConfig, linkedCancelationToken.Token); } - } } diff --git a/LLama.Web/LLama.Web.csproj b/LLama.Web/LLama.Web.csproj index d0e15a62..5a46c5e8 100644 --- a/LLama.Web/LLama.Web.csproj +++ b/LLama.Web/LLama.Web.csproj @@ -14,4 +14,8 @@ + + + + diff --git a/LLama.Web/LLamaModel.cs b/LLama.Web/Models/LLamaModel.cs similarity index 98% rename from LLama.Web/LLamaModel.cs rename to LLama.Web/Models/LLamaModel.cs index e500ba04..71bb290e 100644 --- a/LLama.Web/LLamaModel.cs +++ b/LLama.Web/Models/LLamaModel.cs @@ -2,12 +2,12 @@ using LLama.Web.Common; using System.Collections.Concurrent; -namespace LLama.Web +namespace LLama.Web.Models { /// /// Wrapper class for LLamaSharp LLamaWeights /// - /// + /// public class LLamaModel : IDisposable { private readonly ModelOptions _config; diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs index c53676f2..35413f92 100644 --- a/LLama.Web/Models/ModelSession.cs +++ b/LLama.Web/Models/ModelSession.cs @@ -3,46 +3,97 @@ using LLama.Web.Common; namespace LLama.Web.Models { - public class ModelSession : IDisposable + public class ModelSession { - private bool _isFirstInteraction = true; - private ModelOptions _modelOptions; - private PromptOptions _promptOptions; - private ParameterOptions _inferenceOptions; - private ITextStreamTransform _outputTransform; - private ILLamaExecutor _executor; + private 
readonly string _sessionId; + private readonly LLamaModel _model; + private readonly LLamaContext _context; + private readonly ILLamaExecutor _executor; + private readonly Common.SessionOptions _sessionParams; + private readonly ITextStreamTransform _outputTransform; + private readonly InferenceOptions _defaultInferenceConfig; + private CancellationTokenSource _cancellationTokenSource; - public ModelSession(ILLamaExecutor executor, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions) + public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null) { - _executor = executor; - _modelOptions = modelOptions; - _promptOptions = promptOptions; - _inferenceOptions = parameterOptions; - - _inferenceOptions.AntiPrompts = _promptOptions.AntiPrompt?.Concat(_inferenceOptions.AntiPrompts ?? Enumerable.Empty()).Distinct() ?? _inferenceOptions.AntiPrompts; - if (_promptOptions.OutputFilter?.Count > 0) - _outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptOptions.OutputFilter, redundancyLength: 5); + _model = model; + _context = context; + _sessionId = sessionId; + _sessionParams = sessionOptions; + _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions(); + _outputTransform = CreateOutputFilter(_sessionParams); + _executor = CreateExecutor(_model, _context, _sessionParams); } - public string ModelName + /// + /// Gets the session identifier. + /// + public string SessionId => _sessionId; + + /// + /// Gets the name of the model. + /// + public string ModelName => _sessionParams.Model; + + /// + /// Gets the context. + /// + public LLamaContext Context => _context; + + /// + /// Gets the session configuration. + /// + public Common.SessionOptions SessionConfig => _sessionParams; + + /// + /// Gets the inference parameters. 
+ /// + public InferenceOptions InferenceParams => _defaultInferenceConfig; + + + + /// + /// Initializes the prompt. + /// + /// The inference configuration. + /// The cancellation token. + internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) { - get { return _modelOptions.Name; } + if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless) + return; + + if (string.IsNullOrEmpty(_sessionParams.Prompt)) + return; + + // Run Initial prompt + var inferenceParams = ConfigureInferenceParams(inferenceConfig); + _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token)) + { + // We dont really need the response of the initial prompt, so exit on first token + break; + }; } - public IAsyncEnumerable InferAsync(string message, CancellationTokenSource cancellationTokenSource) + + /// + /// Runs inference on the model context + /// + /// The message. + /// The inference configuration. + /// The cancellation token. 
+ /// + internal IAsyncEnumerable InferAsync(string message, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) { - _cancellationTokenSource = cancellationTokenSource; - if (_isFirstInteraction) - { - _isFirstInteraction = false; - message = _promptOptions.Prompt + message; - } + var inferenceParams = ConfigureInferenceParams(inferenceConfig); + _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + var inferenceStream = _executor.InferAsync(message, inferenceParams, _cancellationTokenSource.Token); if (_outputTransform is not null) - return _outputTransform.TransformAsync(_executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token)); + return _outputTransform.TransformAsync(inferenceStream); - return _executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token); + return inferenceStream; } @@ -56,13 +107,36 @@ namespace LLama.Web.Models return _cancellationTokenSource.IsCancellationRequested; } - public void Dispose() + /// + /// Configures the inference parameters. + /// + /// The inference configuration. + private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig) + { + var inferenceParams = inferenceConfig ?? 
_defaultInferenceConfig; + inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts(); + return inferenceParams; + } + + private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig) { - _inferenceOptions = null; - _outputTransform = null; + var outputFilters = sessionConfig.GetOutputFilters(); + if (outputFilters.Count > 0) + return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters); - _executor?.Context.Dispose(); - _executor = null; + return null; + } + + + private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig) + { + return sessionConfig.ExecutorType switch + { + LLamaExecutorType.Interactive => new InteractiveExecutor(_context), + LLamaExecutorType.Instruct => new InstructExecutor(_context), + LLamaExecutorType.Stateless => new StatelessExecutor(_model.LLamaWeights, _model.ModelParams), + _ => default + }; } } } diff --git a/LLama.Web/Models/ResponseFragment.cs b/LLama.Web/Models/ResponseFragment.cs deleted file mode 100644 index 02f27f13..00000000 --- a/LLama.Web/Models/ResponseFragment.cs +++ /dev/null @@ -1,18 +0,0 @@ -namespace LLama.Web.Models -{ - public class ResponseFragment - { - public ResponseFragment(string id, string content = null, bool isFirst = false, bool isLast = false) - { - Id = id; - IsLast = isLast; - IsFirst = isFirst; - Content = content; - } - - public string Id { get; set; } - public string Content { get; set; } - public bool IsLast { get; set; } - public bool IsFirst { get; set; } - } -} diff --git a/LLama.Web/Models/TokenModel.cs b/LLama.Web/Models/TokenModel.cs new file mode 100644 index 00000000..c95f9ec6 --- /dev/null +++ b/LLama.Web/Models/TokenModel.cs @@ -0,0 +1,24 @@ +namespace LLama.Web.Models +{ + public class TokenModel + { + public TokenModel(string id, string content = null, TokenType tokenType = TokenType.Content) + { + Id = id; + Content = content; + TokenType = tokenType; + } + + public string Id { get; set; } + 
public string Content { get; set; } + public TokenType TokenType { get; set; } + } + + public enum TokenType + { + Begin = 0, + Content = 2, + End = 4, + Cancel = 10 + } +} diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml b/LLama.Web/Pages/Executor/Instruct.cshtml deleted file mode 100644 index 9f8cb2d8..00000000 --- a/LLama.Web/Pages/Executor/Instruct.cshtml +++ /dev/null @@ -1,96 +0,0 @@ -@page -@model InstructModel -@{ - -} -@Html.AntiForgeryToken() -
- -
-
-

Instruct

-
- Hub: Disconnected -
-
- -
- Model - -
- -
- Parameters - -
- -
- Prompt - - -
- -
-
-
- -
-
- -
-
-
- -
-
-
-
- -
-
- -
-
- -
-
- - -
-
-
-
- -
-
- -@{ await Html.RenderPartialAsync("_ChatTemplates"); } - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml.cs b/LLama.Web/Pages/Executor/Instruct.cshtml.cs deleted file mode 100644 index 18a58253..00000000 --- a/LLama.Web/Pages/Executor/Instruct.cshtml.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.Web.Common; -using LLama.Web.Models; -using LLama.Web.Services; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class InstructModel : PageModel - { - private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; - - public InstructModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) - { - _logger = logger; - Options = options.Value; - _modelSessionService = modelSessionService; - } - - public LLamaOptions Options { get; set; } - - public void OnGet() - { - } - - public async Task OnPostCancel(CancelModel model) - { - await _modelSessionService.CancelAsync(model.ConnectionId); - return new JsonResult(default); - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml.css b/LLama.Web/Pages/Executor/Instruct.cshtml.css deleted file mode 100644 index ed9a1d59..00000000 --- a/LLama.Web/Pages/Executor/Instruct.cshtml.css +++ /dev/null @@ -1,4 +0,0 @@ -.section-content { - flex: 1; - overflow-y: scroll; -} diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml b/LLama.Web/Pages/Executor/Interactive.cshtml deleted file mode 100644 index 916b59ca..00000000 --- a/LLama.Web/Pages/Executor/Interactive.cshtml +++ /dev/null @@ -1,96 +0,0 @@ -@page -@model InteractiveModel -@{ - -} -@Html.AntiForgeryToken() -
- -
-
-

Interactive

-
- Hub: Disconnected -
-
- -
- Model - -
- -
- Parameters - -
- -
- Prompt - - -
- -
-
-
- -
-
- -
-
-
- -
-
-
-
- -
-
- -
-
- -
-
- - -
-
-
-
- -
-
- -@{ await Html.RenderPartialAsync("_ChatTemplates");} - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml.cs b/LLama.Web/Pages/Executor/Interactive.cshtml.cs deleted file mode 100644 index 7179a440..00000000 --- a/LLama.Web/Pages/Executor/Interactive.cshtml.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.Web.Common; -using LLama.Web.Models; -using LLama.Web.Services; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class InteractiveModel : PageModel - { - private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; - - public InteractiveModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) - { - _logger = logger; - Options = options.Value; - _modelSessionService = modelSessionService; - } - - public LLamaOptions Options { get; set; } - - public void OnGet() - { - } - - public async Task OnPostCancel(CancelModel model) - { - await _modelSessionService.CancelAsync(model.ConnectionId); - return new JsonResult(default); - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml.css b/LLama.Web/Pages/Executor/Interactive.cshtml.css deleted file mode 100644 index ed9a1d59..00000000 --- a/LLama.Web/Pages/Executor/Interactive.cshtml.css +++ /dev/null @@ -1,4 +0,0 @@ -.section-content { - flex: 1; - overflow-y: scroll; -} diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml b/LLama.Web/Pages/Executor/Stateless.cshtml deleted file mode 100644 index b5d8eea3..00000000 --- a/LLama.Web/Pages/Executor/Stateless.cshtml +++ /dev/null @@ -1,97 +0,0 @@ -@page -@model StatelessModel -@{ - -} -@Html.AntiForgeryToken() -
- -
-
-

Stateless

-
- Hub: Disconnected -
-
- -
- Model - -
- -
- Parameters - -
- -
- Prompt - - -
- -
-
-
- -
-
- -
-
-
- -
-
-
-
- -
-
- -
-
- -
-
- - -
-
-
-
- -
-
- -@{ await Html.RenderPartialAsync("_ChatTemplates"); } - - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml.cs b/LLama.Web/Pages/Executor/Stateless.cshtml.cs deleted file mode 100644 index f88c4b83..00000000 --- a/LLama.Web/Pages/Executor/Stateless.cshtml.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.Web.Common; -using LLama.Web.Models; -using LLama.Web.Services; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class StatelessModel : PageModel - { - private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; - - public StatelessModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) - { - _logger = logger; - Options = options.Value; - _modelSessionService = modelSessionService; - } - - public LLamaOptions Options { get; set; } - - public void OnGet() - { - } - - public async Task OnPostCancel(CancelModel model) - { - await _modelSessionService.CancelAsync(model.ConnectionId); - return new JsonResult(default); - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml.css b/LLama.Web/Pages/Executor/Stateless.cshtml.css deleted file mode 100644 index ed9a1d59..00000000 --- a/LLama.Web/Pages/Executor/Stateless.cshtml.css +++ /dev/null @@ -1,4 +0,0 @@ -.section-content { - flex: 1; - overflow-y: scroll; -} diff --git a/LLama.Web/Pages/Index.cshtml b/LLama.Web/Pages/Index.cshtml index b5f0c15f..55512603 100644 --- a/LLama.Web/Pages/Index.cshtml +++ b/LLama.Web/Pages/Index.cshtml @@ -1,10 +1,121 @@ @page +@using LLama.Web.Common; + @model IndexModel @{ - ViewData["Title"] = "Home page"; + ViewData["Title"] = "Inference Demo"; } -
-

Welcome

-

Learn about building Web apps with ASP.NET Core.

+@Html.AntiForgeryToken() +
+ +
+
+
+ @ViewData["Title"] +
+
+ Socket: Disconnected +
+
+ +
+
+
+
+ Model + @Html.DropDownListFor(m => m.SessionOptions.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) +
+
+ Inference Type + @Html.DropDownListFor(m => m.SessionOptions.ExecutorType, Html.GetEnumSelectList(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) +
+ + +
+
+
+ +
+
+
+ + +
+
+ +
+
+
+ +
+
+
+
+ +
+
+ +
+
+ +
+
+ + +
+
+
+
+ +
+ +@{ + await Html.RenderPartialAsync("_ChatTemplates"); +} + +@section Scripts { + + +} \ No newline at end of file diff --git a/LLama.Web/Pages/Index.cshtml.cs b/LLama.Web/Pages/Index.cshtml.cs index 477c9bfb..3647dfec 100644 --- a/LLama.Web/Pages/Index.cshtml.cs +++ b/LLama.Web/Pages/Index.cshtml.cs @@ -1,5 +1,7 @@ -using Microsoft.AspNetCore.Mvc; +using LLama.Web.Common; +using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.RazorPages; +using Microsoft.Extensions.Options; namespace LLama.Web.Pages { @@ -7,14 +9,33 @@ namespace LLama.Web.Pages { private readonly ILogger _logger; - public IndexModel(ILogger logger) + public IndexModel(ILogger logger, IOptions options) { _logger = logger; + Options = options.Value; } + public LLamaOptions Options { get; set; } + + [BindProperty] + public Common.SessionOptions SessionOptions { get; set; } + + [BindProperty] + public InferenceOptions InferenceOptions { get; set; } + public void OnGet() { + SessionOptions = new Common.SessionOptions + { + Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.", + AntiPrompt = "User:", + // OutputFilter = "User:, Response:" + }; + InferenceOptions = new InferenceOptions + { + Temperature = 0.8f + }; } } } \ No newline at end of file diff --git a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml index 15644012..cd768f1f 100644 --- a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml +++ b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml @@ -12,7 +12,7 @@
- {{text}} + {{text}}
{{date}}
@@ -26,9 +26,7 @@
- - - +
@@ -41,20 +39,6 @@
- \ No newline at end of file diff --git a/LLama.Web/Pages/Shared/_Layout.cshtml b/LLama.Web/Pages/Shared/_Layout.cshtml index 23132bfa..16d6ad52 100644 --- a/LLama.Web/Pages/Shared/_Layout.cshtml +++ b/LLama.Web/Pages/Shared/_Layout.cshtml @@ -3,7 +3,7 @@ - @ViewData["Title"] - LLama.Web + @ViewData["Title"] - LLamaSharp.Web @@ -13,24 +13,26 @@
-
- @RenderBody() -
+
+ @RenderBody() +
- © 2023 - LLama.Web + © 2023 - LLamaSharp.Web
diff --git a/LLama.Web/Pages/Shared/_Parameters.cshtml b/LLama.Web/Pages/Shared/_Parameters.cshtml new file mode 100644 index 00000000..d6e476c4 --- /dev/null +++ b/LLama.Web/Pages/Shared/_Parameters.cshtml @@ -0,0 +1,137 @@ +@page +@using LLama.Common; +@model LLama.Abstractions.IInferenceParams +} + +
+
+ MaxTokens +
+ @Html.TextBoxFor(m => m.MaxTokens, new { @type="range", @class = "slider", min="-1", max="2048", step="1" }) + +
+
+ +
+ TokensKeep +
+ @Html.TextBoxFor(m => m.TokensKeep, new { @type="range", @class = "slider", min="0", max="2048", step="1" }) + +
+
+
+ +
+
+ TopK +
+ @Html.TextBoxFor(m => m.TopK, new { @type="range", @class = "slider", min="-1", max="100", step="1" }) + +
+
+ +
+ TopP +
+ @Html.TextBoxFor(m => m.TopP, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" }) + +
+
+
+ + + +
+
+ TypicalP +
+ @Html.TextBoxFor(m => m.TypicalP, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" }) + +
+
+ +
+ Temperature +
+ @Html.TextBoxFor(m => m.Temperature, new { @type="range", @class = "slider", min="0.0", max="1.5", step="0.01" }) + +
+
+
+ +
+
+ RepeatPenalty +
+ @Html.TextBoxFor(m => m.RepeatPenalty, new { @type="range", @class = "slider", min="0.0", max="2.0", step="0.01" }) + +
+
+ +
+ RepeatLastTokensCount +
+ @Html.TextBoxFor(m => m.RepeatLastTokensCount, new { @type="range", @class = "slider", min="0", max="2048", step="1" }) + +
+
+
+ +
+
+ FrequencyPenalty +
+ @Html.TextBoxFor(m => m.FrequencyPenalty, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" }) + +
+
+ +
+ PresencePenalty +
+ @Html.TextBoxFor(m => m.PresencePenalty, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" }) + +
+
+
+ +
+
+ TfsZ +
+ @Html.TextBoxFor(m => m.TfsZ, new { @type="range", @class = "slider",min="0.0", max="1.0", step="0.01" }) + +
+
+
+ - +
+ + +
+
+
+ + +
+ Sampler Type + @Html.DropDownListFor(m => m.Mirostat, Html.GetEnumSelectList(), new { @class = "form-control form-select" }) +
+ +
+
+ MirostatTau +
+ @Html.TextBoxFor(m => m.MirostatTau, new { @type="range", @class = "slider", min="0.0", max="10.0", step="0.01" }) + +
+
+ +
+ MirostatEta +
+ @Html.TextBoxFor(m => m.MirostatEta, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" }) + +
+
+
\ No newline at end of file diff --git a/LLama.Web/Program.cs b/LLama.Web/Program.cs index 6db653a1..7c4583d2 100644 --- a/LLama.Web/Program.cs +++ b/LLama.Web/Program.cs @@ -1,6 +1,7 @@ using LLama.Web.Common; using LLama.Web.Hubs; using LLama.Web.Services; +using Microsoft.Extensions.DependencyInjection; namespace LLama.Web { @@ -20,7 +21,9 @@ namespace LLama.Web .BindConfiguration(nameof(LLamaOptions)); // Services DI - builder.Services.AddSingleton(); + builder.Services.AddHostedService(); + builder.Services.AddSingleton(); + builder.Services.AddSingleton(); var app = builder.Build(); diff --git a/LLama.Web/Services/ConnectionSessionService.cs b/LLama.Web/Services/ConnectionSessionService.cs deleted file mode 100644 index 7dfcde39..00000000 --- a/LLama.Web/Services/ConnectionSessionService.cs +++ /dev/null @@ -1,94 +0,0 @@ -using LLama.Abstractions; -using LLama.Web.Common; -using LLama.Web.Models; -using Microsoft.Extensions.Options; -using System.Collections.Concurrent; -using System.Drawing; - -namespace LLama.Web.Services -{ - /// - /// Example Service for handling a model session for a websockets connection lifetime - /// Each websocket connection will create its own unique session and context allowing you to use multiple tabs to compare prompts etc - /// - public class ConnectionSessionService : IModelSessionService - { - private readonly LLamaOptions _options; - private readonly ILogger _logger; - private readonly ConcurrentDictionary _modelSessions; - - public ConnectionSessionService(ILogger logger, IOptions options) - { - _logger = logger; - _options = options.Value; - _modelSessions = new ConcurrentDictionary(); - } - - public Task GetAsync(string connectionId) - { - _modelSessions.TryGetValue(connectionId, out var modelSession); - return Task.FromResult(modelSession); - } - - public Task> CreateAsync(LLamaExecutorType executorType, string connectionId, string modelName, string promptName, string parameterName) - { - var modelOption = 
_options.Models.FirstOrDefault(x => x.Name == modelName); - if (modelOption is null) - return Task.FromResult(ServiceResult.FromError($"Model option '{modelName}' not found")); - - var promptOption = _options.Prompts.FirstOrDefault(x => x.Name == promptName); - if (promptOption is null) - return Task.FromResult(ServiceResult.FromError($"Prompt option '{promptName}' not found")); - - var parameterOption = _options.Parameters.FirstOrDefault(x => x.Name == parameterName); - if (parameterOption is null) - return Task.FromResult(ServiceResult.FromError($"Parameter option '{parameterName}' not found")); - - - //Max instance - var currentInstances = _modelSessions.Count(x => x.Value.ModelName == modelOption.Name); - if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances) - return Task.FromResult(ServiceResult.FromError("Maximum model instances reached")); - - // Create model - var llamaModel = new LLamaContext(modelOption); - - // Create executor - ILLamaExecutor executor = executorType switch - { - LLamaExecutorType.Interactive => new InteractiveExecutor(llamaModel), - LLamaExecutorType.Instruct => new InstructExecutor(llamaModel), - LLamaExecutorType.Stateless => new StatelessExecutor(llamaModel), - _ => default - }; - - // Create session - var modelSession = new ModelSession(executor, modelOption, promptOption, parameterOption); - if (!_modelSessions.TryAdd(connectionId, modelSession)) - return Task.FromResult(ServiceResult.FromError("Failed to create model session")); - - return Task.FromResult(ServiceResult.FromValue(modelSession)); - } - - public Task RemoveAsync(string connectionId) - { - if (_modelSessions.TryRemove(connectionId, out var modelSession)) - { - modelSession.CancelInfer(); - modelSession.Dispose(); - return Task.FromResult(true); - } - return Task.FromResult(false); - } - - public Task CancelAsync(string connectionId) - { - if (_modelSessions.TryGetValue(connectionId, out var modelSession)) - { - modelSession.CancelInfer(); 
- return Task.FromResult(true); - } - return Task.FromResult(false); - } - } -} diff --git a/LLama.Web/Services/IModelService.cs b/LLama.Web/Services/IModelService.cs index 0a98f8f4..ec9e4233 100644 --- a/LLama.Web/Services/IModelService.cs +++ b/LLama.Web/Services/IModelService.cs @@ -1,4 +1,5 @@ using LLama.Web.Common; +using LLama.Web.Models; namespace LLama.Web.Services { diff --git a/LLama.Web/Services/IModelSessionService.cs b/LLama.Web/Services/IModelSessionService.cs index 4ee0d483..8723d795 100644 --- a/LLama.Web/Services/IModelSessionService.cs +++ b/LLama.Web/Services/IModelSessionService.cs @@ -1,16 +1,88 @@ -using LLama.Abstractions; -using LLama.Web.Common; +using LLama.Web.Common; using LLama.Web.Models; namespace LLama.Web.Services { public interface IModelSessionService { + /// + /// Gets the ModelSession with the specified Id. + /// + /// The session identifier. + /// The ModelSession if exists, otherwise null Task GetAsync(string sessionId); - Task> CreateAsync(LLamaExecutorType executorType, string sessionId, string modelName, string promptName, string parameterName); - Task RemoveAsync(string sessionId); - Task CancelAsync(string sessionId); - } + /// + /// Gets all ModelSessions + /// + /// A collection oa all Model instances + Task> GetAllAsync(); + + + /// + /// Creates a new ModelSession + /// + /// The session identifier. + /// The session configuration. + /// The default inference configuration, will be used for all inference where no infer configuration is supplied. + /// The cancellation token. + /// + /// + /// Session with id {sessionId} already exists + /// or + /// Failed to create model session + /// + Task CreateAsync(string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); + + + /// + /// Closes the session + /// + /// The session identifier. 
+ /// + Task CloseAsync(string sessionId); + + + /// + /// Runs inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference configuration, if null session default is used + /// The cancellation token. + /// Inference is already running for this session + IAsyncEnumerable InferAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default); + + /// + /// Runs inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference configuration, if null session default is used + /// The cancellation token. + /// Streaming async result of + /// Inference is already running for this session + IAsyncEnumerable InferTextAsync(string sessionId, string prompt, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); + + + /// + /// Queues inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference configuration, if null session default is used + /// The cancellation token. + /// Completed inference result as string + /// Inference is already running for this session + Task InferTextCompleteAsync(string sessionId, string prompt, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); + + + /// + /// Cancels the current inference action. + /// + /// The session identifier. 
+ /// + Task CancelAsync(string sessionId); + } } diff --git a/LLama.Web/Services/ModelLoaderService.cs b/LLama.Web/Services/ModelLoaderService.cs new file mode 100644 index 00000000..7545885d --- /dev/null +++ b/LLama.Web/Services/ModelLoaderService.cs @@ -0,0 +1,42 @@ +namespace LLama.Web.Services +{ + + /// + /// Service for managing loading/preloading of models at app startup + /// + /// Type used to identify contexts + /// + public class ModelLoaderService : IHostedService + { + private readonly IModelService _modelService; + + /// + /// Initializes a new instance of the class. + /// + /// The model service. + public ModelLoaderService(IModelService modelService) + { + _modelService = modelService; + } + + + /// + /// Triggered when the application host is ready to start the service. + /// + /// Indicates that the start process has been aborted. + public async Task StartAsync(CancellationToken cancellationToken) + { + await _modelService.LoadModels(); + } + + + /// + /// Triggered when the application host is performing a graceful shutdown. + /// + /// Indicates that the shutdown process should no longer be graceful. 
+ public async Task StopAsync(CancellationToken cancellationToken) + { + await _modelService.UnloadModels(); + } + } +} diff --git a/LLama.Web/Services/ModelService.cs b/LLama.Web/Services/ModelService.cs index 16365a5d..2a3d4788 100644 --- a/LLama.Web/Services/ModelService.cs +++ b/LLama.Web/Services/ModelService.cs @@ -1,5 +1,6 @@ using LLama.Web.Async; using LLama.Web.Common; +using LLama.Web.Models; using System.Collections.Concurrent; namespace LLama.Web.Services diff --git a/LLama.Web/Services/ModelSessionService.cs b/LLama.Web/Services/ModelSessionService.cs new file mode 100644 index 00000000..e808e630 --- /dev/null +++ b/LLama.Web/Services/ModelSessionService.cs @@ -0,0 +1,216 @@ +using LLama.Web.Async; +using LLama.Web.Common; +using LLama.Web.Models; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Runtime.CompilerServices; + +namespace LLama.Web.Services +{ + /// + /// Example Service for handling a model session for a websockets connection lifetime + /// Each websocket connection will create its own unique session and context allowing you to use multiple tabs to compare prompts etc + /// + public class ModelSessionService : IModelSessionService + { + private readonly AsyncGuard _sessionGuard; + private readonly IModelService _modelService; + private readonly ConcurrentDictionary _modelSessions; + + + /// + /// Initializes a new instance of the class. + /// + /// The model service. + /// The model session state service. + public ModelSessionService(IModelService modelService) + { + _modelService = modelService; + _sessionGuard = new AsyncGuard(); + _modelSessions = new ConcurrentDictionary(); + } + + + /// + /// Gets the ModelSession with the specified Id. + /// + /// The session identifier. + /// The ModelSession if exists, otherwise null + public Task GetAsync(string sessionId) + { + return Task.FromResult(_modelSessions.TryGetValue(sessionId, out var session) ? 
session : null); + } + + + /// + /// Gets all ModelSessions + /// + /// A collection oa all Model instances + public Task> GetAllAsync() + { + return Task.FromResult>(_modelSessions.Values); + } + + + /// + /// Creates a new ModelSession + /// + /// The session identifier. + /// The session configuration. + /// The default inference configuration, will be used for all inference where no infer configuration is supplied. + /// The cancellation token. + /// + /// + /// Session with id {sessionId} already exists + /// or + /// Failed to create model session + /// + public async Task CreateAsync(string sessionId, Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) + { + if (_modelSessions.TryGetValue(sessionId, out _)) + throw new Exception($"Session with id {sessionId} already exists"); + + // Create context + var (model, context) = await _modelService.GetOrCreateModelAndContext(sessionConfig.Model, sessionId); + + // Create session + var modelSession = new ModelSession(model, context, sessionId, sessionConfig, inferenceConfig); + if (!_modelSessions.TryAdd(sessionId, modelSession)) + throw new Exception($"Failed to create model session"); + + // Run initial Prompt + await modelSession.InitializePrompt(inferenceConfig, cancellationToken); + return modelSession; + + } + + + /// + /// Closes the session + /// + /// The session identifier. + /// + public async Task CloseAsync(string sessionId) + { + if (_modelSessions.TryRemove(sessionId, out var modelSession)) + { + modelSession.CancelInfer(); + return await _modelService.RemoveContext(modelSession.ModelName, sessionId); + } + return false; + } + + + /// + /// Runs inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference configuration, if null session default is used + /// The cancellation token. 
+ /// Inference is already running for this session + public async IAsyncEnumerable InferAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (!_sessionGuard.Guard(sessionId)) + throw new Exception($"Inference is already running for this session"); + + try + { + if (!_modelSessions.TryGetValue(sessionId, out var modelSession)) + yield break; + + // Send begin of response + var stopwatch = Stopwatch.GetTimestamp(); + yield return new TokenModel(default, default, TokenType.Begin); + + // Send content of response + await foreach (var token in modelSession.InferAsync(prompt, inferenceConfig, cancellationToken).ConfigureAwait(false)) + { + yield return new TokenModel(default, token); + } + + // Send end of response + var elapsedTime = GetElapsed(stopwatch); + var endTokenType = modelSession.IsInferCanceled() ? TokenType.Cancel : TokenType.End; + var signature = endTokenType == TokenType.Cancel + ? $"Inference cancelled after {elapsedTime / 1000:F0} seconds" + : $"Inference completed in {elapsedTime / 1000:F0} seconds"; + yield return new TokenModel(default, signature, endTokenType); + } + finally + { + _sessionGuard.Release(sessionId); + } + } + + + /// + /// Runs inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference configuration, if null session default is used + /// The cancellation token. 
+ /// Streaming async result of + /// Inference is already running for this session + public IAsyncEnumerable InferTextAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) + { + async IAsyncEnumerable InferTextInternal() + { + await foreach (var token in InferAsync(sessionId, prompt, inferenceConfig, cancellationToken).ConfigureAwait(false)) + { + if (token.TokenType == TokenType.Content) + yield return token.Content; + } + } + return InferTextInternal(); + } + + + /// + /// Runs inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference configuration, if null session default is used + /// The cancellation token. + /// Completed inference result as string + /// Inference is already running for this session + public async Task InferTextCompleteAsync(string sessionId, string prompt, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) + { + var inferResult = await InferAsync(sessionId, prompt, inferenceConfig, cancellationToken) + .Where(x => x.TokenType == TokenType.Content) + .Select(x => x.Content) + .ToListAsync(cancellationToken: cancellationToken); + + return string.Concat(inferResult); + } + + + /// + /// Cancels the current inference action. + /// + /// The session identifier. + /// + public Task CancelAsync(string sessionId) + { + if (_modelSessions.TryGetValue(sessionId, out var modelSession)) + { + modelSession.CancelInfer(); + return Task.FromResult(true); + } + return Task.FromResult(false); + } + + + /// + /// Gets the elapsed time in milliseconds. + /// + /// The timestamp. 
+ /// + private static int GetElapsed(long timestamp) + { + return (int)Stopwatch.GetElapsedTime(timestamp).TotalMilliseconds; + } + } +} diff --git a/LLama.Web/appsettings.json b/LLama.Web/appsettings.json index 9f340a9c..6231b882 100644 --- a/LLama.Web/appsettings.json +++ b/LLama.Web/appsettings.json @@ -7,48 +7,34 @@ }, "AllowedHosts": "*", "LLamaOptions": { + "ModelLoadType": "Single", "Models": [ { "Name": "WizardLM-7B", - "MaxInstances": 2, + "MaxInstances": 20, "ModelPath": "D:\\Repositories\\AI\\Models\\wizardLM-7B.ggmlv3.q4_0.bin", - "ContextSize": 2048 - } - ], - "Parameters": [ - { - "Name": "Default", - "Temperature": 0.6 - } - ], - "Prompts": [ - { - "Name": "None", - "Prompt": "" - }, - { - "Name": "Alpaca", - "Path": "D:\\Repositories\\AI\\Prompts\\alpaca.txt", - "AntiPrompt": [ - "User:" - ], - "OutputFilter": [ - "Response:", - "User:" - ] - }, - { - "Name": "ChatWithBob", - "Path": "D:\\Repositories\\AI\\Prompts\\chat-with-bob.txt", - "AntiPrompt": [ - "User:" - ], - "OutputFilter": [ - "Bob:", - "User:" - ] + "ContextSize": 2048, + "BatchSize": 2048, + "Threads": 4, + "GpuLayerCount": 6, + "UseMemorymap": true, + "UseMemoryLock": false, + "MainGpu": 0, + "LowVram": false, + "Seed": 1686349486, + "UseFp16Memory": true, + "Perplexity": false, + "LoraAdapter": "", + "LoraBase": "", + "EmbeddingMode": false, + "TensorSplits": null, + "GroupedQueryAttention": 1, + "RmsNormEpsilon": 0.000005, + "RopeFrequencyBase": 10000.0, + "RopeFrequencyScale": 1.0, + "MulMatQ": false, + "Encoding": "UTF-8" } ] - } } diff --git a/LLama.Web/wwwroot/css/site.css b/LLama.Web/wwwroot/css/site.css index d10ef975..14685f45 100644 --- a/LLama.Web/wwwroot/css/site.css +++ b/LLama.Web/wwwroot/css/site.css @@ -22,13 +22,30 @@ footer { @media (min-width: 768px) { - html { - font-size: 16px; - } + html { + font-size: 16px; + } } .btn:focus, .btn:active:focus, .btn-link.nav-link:focus, .form-control:focus, .form-check-input:focus { - box-shadow: 0 0 0 0.1rem white, 0 0 0 
0.25rem #258cfb; + box-shadow: 0 0 0 0.1rem white, 0 0 0 0.25rem #258cfb; +} + +#scroll-container { + flex: 1; + overflow-y: scroll; +} + +#output-container .content { + white-space: break-spaces; } +.slider-container > .slider { + width: 100%; +} + +.slider-container > label { + width: 50px; + text-align: center; +} diff --git a/LLama.Web/wwwroot/js/sessionConnectionChat.js b/LLama.Web/wwwroot/js/sessionConnectionChat.js index 472b5971..719c44ac 100644 --- a/LLama.Web/wwwroot/js/sessionConnectionChat.js +++ b/LLama.Web/wwwroot/js/sessionConnectionChat.js @@ -1,26 +1,26 @@ -const createConnectionSessionChat = (LLamaExecutorType) => { +const createConnectionSessionChat = () => { const outputErrorTemplate = $("#outputErrorTemplate").html(); const outputInfoTemplate = $("#outputInfoTemplate").html(); const outputUserTemplate = $("#outputUserTemplate").html(); const outputBotTemplate = $("#outputBotTemplate").html(); - const sessionDetailsTemplate = $("#sessionDetailsTemplate").html(); + const signatureTemplate = $("#signatureTemplate").html(); - let connectionId; + let inferenceSession; const connection = new signalR.HubConnectionBuilder().withUrl("/SessionConnectionHub").build(); const scrollContainer = $("#scroll-container"); const outputContainer = $("#output-container"); const chatInput = $("#input"); - const onStatus = (connection, status) => { - connectionId = connection; if (status == Enums.SessionConnectionStatus.Connected) { $("#socket").text("Connected").addClass("text-success"); } else if (status == Enums.SessionConnectionStatus.Loaded) { + loaderHide(); enableControls(); - $("#session-details").html(Mustache.render(sessionDetailsTemplate, { model: getSelectedModel(), prompt: getSelectedPrompt(), parameter: getSelectedParameter() })); + $("#load").hide(); + $("#unload").show(); onInfo(`New model session successfully started`) } } @@ -36,30 +36,31 @@ const createConnectionSessionChat = (LLamaExecutorType) => { let responseContent; let responseContainer; - 
let responseFirstFragment; + let responseFirstToken; const onResponse = (response) => { if (!response) return; - if (response.isFirst) { - outputContainer.append(Mustache.render(outputBotTemplate, response)); - responseContainer = $(`#${response.id}`); + if (response.tokenType == Enums.TokenType.Begin) { + const uniqueId = randomString(); + outputContainer.append(Mustache.render(outputBotTemplate, { id: uniqueId, ...response })); + responseContainer = $(`#${uniqueId}`); responseContent = responseContainer.find(".content"); - responseFirstFragment = true; + responseFirstToken = true; scrollToBottom(true); return; } - if (response.isLast) { + if (response.tokenType == Enums.TokenType.End || response.tokenType == Enums.TokenType.Cancel) { enableControls(); - responseContainer.find(".signature").append(response.content); + responseContainer.find(".signature").append(Mustache.render(signatureTemplate, response)); scrollToBottom(); } else { - if (responseFirstFragment) { + if (responseFirstToken) { responseContent.empty(); - responseFirstFragment = false; + responseFirstToken = false; responseContainer.find(".date").append(getDateTime()); } responseContent.append(response.content); @@ -67,45 +68,88 @@ const createConnectionSessionChat = (LLamaExecutorType) => { } } - const sendPrompt = async () => { const text = chatInput.val(); if (text) { + chatInput.val(null); disableControls(); outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() })); - await connection.invoke('SendPrompt', text); - chatInput.val(null); + inferenceSession = await connection + .stream("SendPrompt", text, serializeFormToJson('SessionParameters')) + .subscribe({ + next: onResponse, + complete: onResponse, + error: onError, + }); scrollToBottom(true); } } const cancelPrompt = async () => { - await ajaxPostJsonAsync('?handler=Cancel', { connectionId: connectionId }); + if (inferenceSession) + inferenceSession.dispose(); } const loadModel = async () => { - const 
modelName = getSelectedModel(); - const promptName = getSelectedPrompt(); - const parameterName = getSelectedParameter(); - if (!modelName || !promptName || !parameterName) { - onError("Please select a valid Model, Parameter and Prompt"); - return; - } + const sessionParams = serializeFormToJson('SessionParameters'); + loaderShow(); + disableControls(); + disablePromptControls(); + $("#load").attr("disabled", "disabled"); + // TODO: Split parameters sets + await connection.invoke('LoadModel', sessionParams, sessionParams); + } + + const unloadModel = async () => { disableControls(); - await connection.invoke('LoadModel', LLamaExecutorType, modelName, promptName, parameterName); + enablePromptControls(); + $("#load").removeAttr("disabled"); } + const serializeFormToJson = (form) => { + const formDataJson = {}; + const formData = new FormData(document.getElementById(form)); + formData.forEach((value, key) => { + + if (key.includes(".")) + key = key.split(".")[1]; + + // Convert number strings to numbers + if (!isNaN(value) && value.trim() !== "") { + formDataJson[key] = parseFloat(value); + } + // Convert boolean strings to booleans + else if (value === "true" || value === "false") { + formDataJson[key] = (value === "true"); + } + else { + formDataJson[key] = value; + } + }); + return formDataJson; + } const enableControls = () => { $(".input-control").removeAttr("disabled"); } - const disableControls = () => { $(".input-control").attr("disabled", "disabled"); } + const enablePromptControls = () => { + $("#load").show(); + $("#unload").hide(); + $(".prompt-control").removeAttr("disabled"); + activatePromptTab(); + } + + const disablePromptControls = () => { + $(".prompt-control").attr("disabled", "disabled"); + activateParamsTab(); + } + const clearOutput = () => { outputContainer.empty(); } @@ -117,27 +161,14 @@ const createConnectionSessionChat = (LLamaExecutorType) => { customPrompt.text(selectedValue); } - - const getSelectedModel = () => { - return 
$("option:selected", "#Model").val(); - } - - - const getSelectedParameter = () => { - return $("option:selected", "#Parameter").val(); - } - - - const getSelectedPrompt = () => { - return $("option:selected", "#Prompt").val(); - } - - const getDateTime = () => { const dateTime = new Date(); return dateTime.toLocaleString(); } + const randomString = () => { + return Math.random().toString(36).slice(2); + } const scrollToBottom = (force) => { const scrollTop = scrollContainer.scrollTop(); @@ -151,10 +182,25 @@ const createConnectionSessionChat = (LLamaExecutorType) => { } } + const activatePromptTab = () => { + $("#nav-prompt-tab").trigger("click"); + } + const activateParamsTab = () => { + $("#nav-params-tab").trigger("click"); + } + + const loaderShow = () => { + $(".spinner").show(); + } + + const loaderHide = () => { + $(".spinner").hide(); + } // Map UI functions $("#load").on("click", loadModel); + $("#unload").on("click", unloadModel); $("#send").on("click", sendPrompt); $("#clear").on("click", clearOutput); $("#cancel").on("click", cancelPrompt); @@ -165,7 +211,10 @@ const createConnectionSessionChat = (LLamaExecutorType) => { sendPrompt(); } }); - + $(".slider").on("input", function (e) { + const slider = $(this); + slider.next().text(slider.val()); + }).trigger("input"); // Map signalr functions diff --git a/LLama.Web/wwwroot/js/site.js b/LLama.Web/wwwroot/js/site.js index 2f679669..6612c772 100644 --- a/LLama.Web/wwwroot/js/site.js +++ b/LLama.Web/wwwroot/js/site.js @@ -40,11 +40,17 @@ const Enums = { Loaded: 4, Connected: 10 }), - LLamaExecutorType: Object.freeze({ + ExecutorType: Object.freeze({ Interactive: 0, Instruct: 1, Stateless: 2 }), + TokenType: Object.freeze({ + Begin: 0, + Content: 2, + End: 4, + Cancel: 10 + }), GetName: (enumType, enumKey) => { return Object.keys(enumType)[enumKey] }, From e2a17d6b6f0490cfbe3d88b66d1be4eab58daaa1 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Wed, 4 Oct 2023 13:35:18 +1300 Subject: [PATCH 3/7] Refactor 
conflicting object name SessionOptions --- LLama.Web/Common/ISessionConfig.cs | 13 ++++++++ .../{SessionOptions.cs => SessionConfig.cs} | 2 +- LLama.Web/{Extensioms.cs => Extensions.cs} | 6 ++-- LLama.Web/Hubs/SessionConnectionHub.cs | 2 +- LLama.Web/Models/ModelSession.cs | 30 +++++++++---------- LLama.Web/Pages/Index.cshtml | 10 +++---- LLama.Web/Pages/Index.cshtml.cs | 4 +-- LLama.Web/Services/IModelSessionService.cs | 4 +-- LLama.Web/Services/ModelSessionService.cs | 2 +- 9 files changed, 43 insertions(+), 30 deletions(-) create mode 100644 LLama.Web/Common/ISessionConfig.cs rename LLama.Web/Common/{SessionOptions.cs => SessionConfig.cs} (89%) rename LLama.Web/{Extensioms.cs => Extensions.cs} (88%) diff --git a/LLama.Web/Common/ISessionConfig.cs b/LLama.Web/Common/ISessionConfig.cs new file mode 100644 index 00000000..09bddc2d --- /dev/null +++ b/LLama.Web/Common/ISessionConfig.cs @@ -0,0 +1,13 @@ +namespace LLama.Web.Common +{ + public interface ISessionConfig + { + string AntiPrompt { get; set; } + List AntiPrompts { get; set; } + LLamaExecutorType ExecutorType { get; set; } + string Model { get; set; } + string OutputFilter { get; set; } + List OutputFilters { get; set; } + string Prompt { get; set; } + } +} \ No newline at end of file diff --git a/LLama.Web/Common/SessionOptions.cs b/LLama.Web/Common/SessionConfig.cs similarity index 89% rename from LLama.Web/Common/SessionOptions.cs rename to LLama.Web/Common/SessionConfig.cs index 34386955..f0a2d22b 100644 --- a/LLama.Web/Common/SessionOptions.cs +++ b/LLama.Web/Common/SessionConfig.cs @@ -1,6 +1,6 @@ namespace LLama.Web.Common { - public class SessionOptions + public class SessionConfig : ISessionConfig { public string Model { get; set; } public string Prompt { get; set; } diff --git a/LLama.Web/Extensioms.cs b/LLama.Web/Extensions.cs similarity index 88% rename from LLama.Web/Extensioms.cs rename to LLama.Web/Extensions.cs index 50bb55c4..99f745dd 100644 --- a/LLama.Web/Extensioms.cs +++ 
b/LLama.Web/Extensions.cs @@ -2,14 +2,14 @@ namespace LLama.Web { - public static class Extensioms + public static class Extensions { /// /// Combines the AntiPrompts list and AntiPrompt csv /// /// The session configuration. /// Combined AntiPrompts with duplicates removed - public static List GetAntiPrompts(this Common.SessionOptions sessionConfig) + public static List GetAntiPrompts(this ISessionConfig sessionConfig) { return CombineCSV(sessionConfig.AntiPrompts, sessionConfig.AntiPrompt); } @@ -19,7 +19,7 @@ namespace LLama.Web ///
/// The session configuration. /// Combined OutputFilters with duplicates removed - public static List GetOutputFilters(this Common.SessionOptions sessionConfig) + public static List GetOutputFilters(this ISessionConfig sessionConfig) { return CombineCSV(sessionConfig.OutputFilters, sessionConfig.OutputFilter); } diff --git a/LLama.Web/Hubs/SessionConnectionHub.cs b/LLama.Web/Hubs/SessionConnectionHub.cs index 730d4e87..24457683 100644 --- a/LLama.Web/Hubs/SessionConnectionHub.cs +++ b/LLama.Web/Hubs/SessionConnectionHub.cs @@ -37,7 +37,7 @@ namespace LLama.Web.Hubs [HubMethodName("LoadModel")] - public async Task OnLoadModel(Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig) + public async Task OnLoadModel(ISessionConfig sessionConfig, InferenceOptions inferenceConfig) { _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId); await _modelSessionService.CloseAsync(Context.ConnectionId); diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs index 35413f92..91c8920f 100644 --- a/LLama.Web/Models/ModelSession.cs +++ b/LLama.Web/Models/ModelSession.cs @@ -9,21 +9,21 @@ namespace LLama.Web.Models private readonly LLamaModel _model; private readonly LLamaContext _context; private readonly ILLamaExecutor _executor; - private readonly Common.SessionOptions _sessionParams; + private readonly ISessionConfig _sessionConfig; private readonly ITextStreamTransform _outputTransform; private readonly InferenceOptions _defaultInferenceConfig; private CancellationTokenSource _cancellationTokenSource; - public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null) + public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null) { _model = model; _context = context; _sessionId = sessionId; - _sessionParams = 
sessionOptions; + _sessionConfig = sessionConfig; _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions(); - _outputTransform = CreateOutputFilter(_sessionParams); - _executor = CreateExecutor(_model, _context, _sessionParams); + _outputTransform = CreateOutputFilter(); + _executor = CreateExecutor(); } /// @@ -34,7 +34,7 @@ namespace LLama.Web.Models /// /// Gets the name of the model. /// - public string ModelName => _sessionParams.Model; + public string ModelName => _sessionConfig.Model; /// /// Gets the context. @@ -44,7 +44,7 @@ namespace LLama.Web.Models /// /// Gets the session configuration. /// - public Common.SessionOptions SessionConfig => _sessionParams; + public ISessionConfig SessionConfig => _sessionConfig; /// /// Gets the inference parameters. @@ -60,16 +60,16 @@ namespace LLama.Web.Models /// The cancellation token. internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) { - if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless) + if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless) return; - if (string.IsNullOrEmpty(_sessionParams.Prompt)) + if (string.IsNullOrEmpty(_sessionConfig.Prompt)) return; // Run Initial prompt var inferenceParams = ConfigureInferenceParams(inferenceConfig); _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token)) + await foreach (var _ in _executor.InferAsync(_sessionConfig.Prompt, inferenceParams, _cancellationTokenSource.Token)) { // We dont really need the response of the initial prompt, so exit on first token break; @@ -114,13 +114,13 @@ namespace LLama.Web.Models private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig) { var inferenceParams = inferenceConfig ?? 
_defaultInferenceConfig; - inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts(); + inferenceParams.AntiPrompts = _sessionConfig.GetAntiPrompts(); return inferenceParams; } - private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig) + private ITextStreamTransform CreateOutputFilter() { - var outputFilters = sessionConfig.GetOutputFilters(); + var outputFilters = _sessionConfig.GetOutputFilters(); if (outputFilters.Count > 0) return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters); @@ -128,9 +128,9 @@ namespace LLama.Web.Models } - private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig) + private ILLamaExecutor CreateExecutor() { - return sessionConfig.ExecutorType switch + return _sessionConfig.ExecutorType switch { LLamaExecutorType.Interactive => new InteractiveExecutor(_context), LLamaExecutorType.Instruct => new InstructExecutor(_context), diff --git a/LLama.Web/Pages/Index.cshtml b/LLama.Web/Pages/Index.cshtml index 55512603..3df4b699 100644 --- a/LLama.Web/Pages/Index.cshtml +++ b/LLama.Web/Pages/Index.cshtml @@ -24,11 +24,11 @@
Model - @Html.DropDownListFor(m => m.SessionOptions.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) + @Html.DropDownListFor(m => m.SessionConfig.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
Inference Type - @Html.DropDownListFor(m => m.SessionOptions.ExecutorType, Html.GetEnumSelectList(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) + @Html.DropDownListFor(m => m.SessionConfig.ExecutorType, Html.GetEnumSelectList(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"})
/// The session identifier. - /// The session configuration. + /// The session configuration. /// The default inference configuration, will be used for all inference where no infer configuration is supplied. /// The cancellation token. /// @@ -33,7 +33,7 @@ namespace LLama.Web.Services /// or /// Failed to create model session /// - Task CreateAsync(string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); + Task CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null, CancellationToken cancellationToken = default); /// diff --git a/LLama.Web/Services/ModelSessionService.cs b/LLama.Web/Services/ModelSessionService.cs index e808e630..84070d94 100644 --- a/LLama.Web/Services/ModelSessionService.cs +++ b/LLama.Web/Services/ModelSessionService.cs @@ -65,7 +65,7 @@ namespace LLama.Web.Services /// or /// Failed to create model session /// - public async Task CreateAsync(string sessionId, Common.SessionOptions sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) + public async Task CreateAsync(string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default) { if (_modelSessions.TryGetValue(sessionId, out _)) throw new Exception($"Session with id {sessionId} already exists"); From 9b8de007dc5e26ac425d916c15191907580a8b54 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Wed, 4 Oct 2023 13:47:08 +1300 Subject: [PATCH 4/7] Propagate ILogger --- LLama.Examples/NewVersion/CodingAssistant.cs | 2 +- LLama.Web/Models/LLamaModel.cs | 6 ++++-- LLama.Web/Services/ModelService.cs | 6 ++++-- LLama/LLamaInstructExecutor.cs | 6 ++++-- LLama/LLamaInteractExecutor.cs | 4 +++- LLama/LLamaStatelessExecutor.cs | 6 +++++- LLama/LLamaWeights.cs | 6 ++++-- 7 files changed, 25 insertions(+), 11 deletions(-) diff --git 
a/LLama.Examples/NewVersion/CodingAssistant.cs b/LLama.Examples/NewVersion/CodingAssistant.cs index 69e997d3..9108e01d 100644 --- a/LLama.Examples/NewVersion/CodingAssistant.cs +++ b/LLama.Examples/NewVersion/CodingAssistant.cs @@ -31,7 +31,7 @@ }; using var model = LLamaWeights.LoadFromFile(parameters); using var context = model.CreateContext(parameters); - var executor = new InstructExecutor(context, InstructionPrefix, InstructionSuffix); + var executor = new InstructExecutor(context, null!, InstructionPrefix, InstructionSuffix); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions." + diff --git a/LLama.Web/Models/LLamaModel.cs b/LLama.Web/Models/LLamaModel.cs index 71bb290e..5aedc5f5 100644 --- a/LLama.Web/Models/LLamaModel.cs +++ b/LLama.Web/Models/LLamaModel.cs @@ -10,6 +10,7 @@ namespace LLama.Web.Models /// public class LLamaModel : IDisposable { + private readonly ILogger _llamaLogger; private readonly ModelOptions _config; private readonly LLamaWeights _weights; private readonly ConcurrentDictionary _contexts; @@ -18,9 +19,10 @@ namespace LLama.Web.Models /// Initializes a new instance of the class. /// /// The model parameters. 
- public LLamaModel(ModelOptions modelParams) + public LLamaModel(ModelOptions modelParams, ILogger llamaLogger) { _config = modelParams; + _llamaLogger = llamaLogger; _weights = LLamaWeights.LoadFromFile(modelParams); _contexts = new ConcurrentDictionary(); } @@ -56,7 +58,7 @@ namespace LLama.Web.Models if (_config.MaxInstances > -1 && ContextCount >= _config.MaxInstances) throw new Exception($"Maximum model instances reached"); - context = _weights.CreateContext(_config); + context = _weights.CreateContext(_config, _llamaLogger); if (_contexts.TryAdd(contextName, context)) return Task.FromResult(context); diff --git a/LLama.Web/Services/ModelService.cs b/LLama.Web/Services/ModelService.cs index 2a3d4788..dfb34bb6 100644 --- a/LLama.Web/Services/ModelService.cs +++ b/LLama.Web/Services/ModelService.cs @@ -11,6 +11,7 @@ namespace LLama.Web.Services ///
public class ModelService : IModelService { + private readonly ILogger _llamaLogger; private readonly AsyncLock _modelLock; private readonly AsyncLock _contextLock; private readonly LLamaOptions _configuration; @@ -22,8 +23,9 @@ namespace LLama.Web.Services ///
/// The logger. /// The options. - public ModelService(LLamaOptions configuration) + public ModelService(LLamaOptions configuration, ILogger llamaLogger) { + _llamaLogger = llamaLogger; _modelLock = new AsyncLock(); _contextLock = new AsyncLock(); _configuration = configuration; @@ -52,7 +54,7 @@ namespace LLama.Web.Services await UnloadModels(); - model = new LLamaModel(modelOptions); + model = new LLamaModel(modelOptions, _llamaLogger); _modelInstances.TryAdd(modelOptions.Name, model); return model; } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 6faa3db2..dab34106 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -9,6 +9,7 @@ using System.Text.Json; using System.Text.Json.Serialization; using System.Threading.Tasks; using LLama.Extensions; +using Microsoft.Extensions.Logging; namespace LLama { @@ -27,10 +28,11 @@ namespace LLama /// /// /// + /// /// /// - public InstructExecutor(LLamaContext context, string instructionPrefix = "\n\n### Instruction:\n\n", - string instructionSuffix = "\n\n### Response:\n\n") : base(context) + public InstructExecutor(LLamaContext context, ILogger logger = null!, string instructionPrefix = "\n\n### Instruction:\n\n", + string instructionSuffix = "\n\n### Response:\n\n") : base(context, logger) { _inp_pfx = Context.Tokenize(instructionPrefix, true); _inp_sfx = Context.Tokenize(instructionSuffix, false); diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index ab403212..0f374e09 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -9,6 +9,7 @@ using System.Text.Json; using System.Text.Json.Serialization; using System.Threading.Tasks; using LLama.Extensions; +using Microsoft.Extensions.Logging; namespace LLama { @@ -25,7 +26,8 @@ namespace LLama /// /// /// - public InteractiveExecutor(LLamaContext context) : base(context) + /// + public InteractiveExecutor(LLamaContext context, ILogger logger = null!) 
: base(context, logger) { _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle); } diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 3ff755a0..e5348bb4 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -7,6 +7,7 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using LLama.Extensions; +using Microsoft.Extensions.Logging; namespace LLama { @@ -19,6 +20,7 @@ namespace LLama public class StatelessExecutor : ILLamaExecutor { + private readonly ILogger? _logger; private readonly LLamaWeights _weights; private readonly IModelParams _params; @@ -32,8 +34,10 @@ namespace LLama /// /// /// - public StatelessExecutor(LLamaWeights weights, IModelParams @params) + /// + public StatelessExecutor(LLamaWeights weights, IModelParams @params, ILogger logger = null!) { + _logger = logger; _weights = weights; _params = @params; diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 1b067f1b..d841d5a9 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -3,6 +3,7 @@ using System.Text; using LLama.Abstractions; using LLama.Extensions; using LLama.Native; +using Microsoft.Extensions.Logging; namespace LLama { @@ -72,10 +73,11 @@ namespace LLama /// Create a llama_context using this model /// /// + /// /// - public LLamaContext CreateContext(IModelParams @params) + public LLamaContext CreateContext(IModelParams @params, ILogger logger = default!) 
{ - return new LLamaContext(this, @params); + return new LLamaContext(this, @params, logger); } } } From a8a498dc12c0c74c837e4a551c05c38e4c63aca5 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Wed, 4 Oct 2023 16:32:13 +1300 Subject: [PATCH 5/7] Fix up issues found during testing --- LLama.Web/Extensions.cs | 2 +- LLama.Web/Hubs/SessionConnectionHub.cs | 2 +- LLama.Web/Pages/Shared/_ChatTemplates.cshtml | 2 +- LLama.Web/Pages/Shared/_Parameters.cshtml | 1 - LLama.Web/Program.cs | 2 ++ LLama.Web/Services/ModelService.cs | 7 ++++--- LLama.Web/appsettings.json | 6 +++--- LLama.Web/wwwroot/js/sessionConnectionChat.js | 5 +++-- 8 files changed, 15 insertions(+), 12 deletions(-) diff --git a/LLama.Web/Extensions.cs b/LLama.Web/Extensions.cs index 99f745dd..ee8d7f7f 100644 --- a/LLama.Web/Extensions.cs +++ b/LLama.Web/Extensions.cs @@ -33,7 +33,7 @@ namespace LLama.Web /// Combined list with duplicates removed private static List CombineCSV(List list, string csv) { - var results = list?.Count == 0 + var results = list is null || list.Count == 0 ? 
CommaSeperatedToList(csv) : CommaSeperatedToList(csv).Concat(list); return results diff --git a/LLama.Web/Hubs/SessionConnectionHub.cs b/LLama.Web/Hubs/SessionConnectionHub.cs index 24457683..966ec8a4 100644 --- a/LLama.Web/Hubs/SessionConnectionHub.cs +++ b/LLama.Web/Hubs/SessionConnectionHub.cs @@ -37,7 +37,7 @@ namespace LLama.Web.Hubs [HubMethodName("LoadModel")] - public async Task OnLoadModel(ISessionConfig sessionConfig, InferenceOptions inferenceConfig) + public async Task OnLoadModel(SessionConfig sessionConfig, InferenceOptions inferenceConfig) { _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId); await _modelSessionService.CloseAsync(Context.ConnectionId); diff --git a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml index cd768f1f..624f5859 100644 --- a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml +++ b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml @@ -25,7 +25,7 @@
-
+
diff --git a/LLama.Web/Pages/Shared/_Parameters.cshtml b/LLama.Web/Pages/Shared/_Parameters.cshtml index d6e476c4..76f3e321 100644 --- a/LLama.Web/Pages/Shared/_Parameters.cshtml +++ b/LLama.Web/Pages/Shared/_Parameters.cshtml @@ -1,7 +1,6 @@ @page @using LLama.Common; @model LLama.Abstractions.IInferenceParams -}
diff --git a/LLama.Web/Program.cs b/LLama.Web/Program.cs index 7c4583d2..193090d0 100644 --- a/LLama.Web/Program.cs +++ b/LLama.Web/Program.cs @@ -14,6 +14,8 @@ namespace LLama.Web // Add services to the container. builder.Services.AddRazorPages(); builder.Services.AddSignalR(); + builder.Logging.ClearProviders(); + builder.Services.AddLogging((loggingBuilder) => loggingBuilder.SetMinimumLevel(LogLevel.Trace).AddConsole()); // Load InteractiveOptions builder.Services.AddOptions() diff --git a/LLama.Web/Services/ModelService.cs b/LLama.Web/Services/ModelService.cs index dfb34bb6..3634f6ab 100644 --- a/LLama.Web/Services/ModelService.cs +++ b/LLama.Web/Services/ModelService.cs @@ -1,6 +1,7 @@ using LLama.Web.Async; using LLama.Web.Common; using LLama.Web.Models; +using Microsoft.Extensions.Options; using System.Collections.Concurrent; namespace LLama.Web.Services @@ -11,10 +12,10 @@ namespace LLama.Web.Services /// public class ModelService : IModelService { - private readonly ILogger _llamaLogger; private readonly AsyncLock _modelLock; private readonly AsyncLock _contextLock; private readonly LLamaOptions _configuration; + private readonly ILogger _llamaLogger; private readonly ConcurrentDictionary _modelInstances; @@ -23,12 +24,12 @@ namespace LLama.Web.Services /// /// The logger. /// The options. 
- public ModelService(LLamaOptions configuration, ILogger llamaLogger) + public ModelService(IOptions configuration, ILogger llamaLogger) { _llamaLogger = llamaLogger; _modelLock = new AsyncLock(); _contextLock = new AsyncLock(); - _configuration = configuration; + _configuration = configuration.Value; _modelInstances = new ConcurrentDictionary(); } diff --git a/LLama.Web/appsettings.json b/LLama.Web/appsettings.json index 6231b882..82d62b1a 100644 --- a/LLama.Web/appsettings.json +++ b/LLama.Web/appsettings.json @@ -7,12 +7,12 @@ }, "AllowedHosts": "*", "LLamaOptions": { - "ModelLoadType": "Single", + "ModelLoadType": 0, "Models": [ { - "Name": "WizardLM-7B", + "Name": "LLama2-7b-Chat", "MaxInstances": 20, - "ModelPath": "D:\\Repositories\\AI\\Models\\wizardLM-7B.ggmlv3.q4_0.bin", + "ModelPath": "..\\LLama.Unittest\\Models\\llama-2-7b-chat.Q4_0.gguf", "ContextSize": 2048, "BatchSize": 2048, "Threads": 4, diff --git a/LLama.Web/wwwroot/js/sessionConnectionChat.js b/LLama.Web/wwwroot/js/sessionConnectionChat.js index 719c44ac..24821150 100644 --- a/LLama.Web/wwwroot/js/sessionConnectionChat.js +++ b/LLama.Web/wwwroot/js/sessionConnectionChat.js @@ -43,8 +43,8 @@ const createConnectionSessionChat = () => { return; if (response.tokenType == Enums.TokenType.Begin) { - const uniqueId = randomString(); - outputContainer.append(Mustache.render(outputBotTemplate, { id: uniqueId, ...response })); + let uniqueId = randomString(); + outputContainer.append(Mustache.render(outputBotTemplate, { uniqueId: uniqueId, ...response })); responseContainer = $(`#${uniqueId}`); responseContent = responseContainer.find(".content"); responseFirstToken = true; @@ -102,6 +102,7 @@ const createConnectionSessionChat = () => { } const unloadModel = async () => { + await cancelPrompt(); disableControls(); enablePromptControls(); $("#load").removeAttr("disabled"); From 4ec9aed47a0fe4bc6a2ffc14addf429ce8599846 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Fri, 20 Oct 2023 08:29:26 +1300 Subject: 
[PATCH 6/7] Revert LLamaSharp project changes --- LLama/LLamaExecutorBase.cs | 3 +-- LLama/LLamaInstructExecutor.cs | 2 +- LLama/LLamaInteractExecutor.cs | 2 +- LLama/LLamaStatelessExecutor.cs | 3 +-- LLama/LLamaWeights.cs | 5 ++--- 5 files changed, 6 insertions(+), 9 deletions(-) diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 242ae10b..0c8e4679 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -75,10 +75,9 @@ namespace LLama /// /// /// - protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) + protected StatefulExecutorBase(LLamaContext context) { Context = context; - _logger = logger; _pastTokensCount = 0; _consumedTokensCount = 0; _n_session_consumed = 0; diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index a4e7c0fd..c7cb55fe 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -32,7 +32,7 @@ namespace LLama /// /// public InstructExecutor(LLamaContext context, ILogger logger = null!, string instructionPrefix = "\n\n### Instruction:\n\n", - string instructionSuffix = "\n\n### Response:\n\n") : base(context, logger) + string instructionSuffix = "\n\n### Response:\n\n") : base(context) { _inp_pfx = Context.Tokenize(instructionPrefix, true); _inp_sfx = Context.Tokenize(instructionSuffix, false); diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 0f374e09..8247ca10 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -27,7 +27,7 @@ namespace LLama /// /// /// - public InteractiveExecutor(LLamaContext context, ILogger logger = null!) 
: base(context, logger) + public InteractiveExecutor(LLamaContext context) : base(context) { _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle); } diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 08a78f9e..d1b73c2f 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -36,9 +36,8 @@ namespace LLama /// /// /// - public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger logger = null!) + public StatelessExecutor(LLamaWeights weights, IContextParams @params) { - _logger = logger; _weights = weights; _params = @params; diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 76d46d25..5dc2024d 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -81,11 +81,10 @@ namespace LLama /// Create a llama_context using this model /// /// - /// /// - public LLamaContext CreateContext(IContextParams @params, ILogger logger = default!) + public LLamaContext CreateContext(IContextParams @params) { - return new LLamaContext(this, @params, logger); + return new LLamaContext(this, @params); } } } From 952e77f97b5cba997cd3439e44ad247a0e83dab8 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Fri, 20 Oct 2023 08:33:27 +1300 Subject: [PATCH 7/7] Remove old parameter --- LLama.Web/Models/LLamaModel.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama.Web/Models/LLamaModel.cs b/LLama.Web/Models/LLamaModel.cs index 5aedc5f5..61341d42 100644 --- a/LLama.Web/Models/LLamaModel.cs +++ b/LLama.Web/Models/LLamaModel.cs @@ -58,7 +58,7 @@ namespace LLama.Web.Models if (_config.MaxInstances > -1 && ContextCount >= _config.MaxInstances) throw new Exception($"Maximum model instances reached"); - context = _weights.CreateContext(_config, _llamaLogger); + context = _weights.CreateContext(_config); if (_contexts.TryAdd(contextName, context)) return Task.FromResult(context);