diff --git a/LLama.Web/Async/AsyncLock.cs b/LLama.Web/Async/AsyncLock.cs
new file mode 100644
index 00000000..09ccb0f7
--- /dev/null
+++ b/LLama.Web/Async/AsyncLock.cs
@@ -0,0 +1,61 @@
+namespace LLama.Web.Async
+{
+    /// <summary>
+    /// Asynchronous mutual-exclusion lock designed for the <c>using</c> statement:
+    /// <c>using (await asyncLock.LockAsync()) { ... }</c>.
+    /// The lock is not re-entrant: requesting it again while already holding it never completes.
+    /// </summary>
+    public sealed class AsyncLock
+    {
+        // Binary semaphore (initial count 1, maximum count 1) that provides the exclusion.
+        private readonly SemaphoreSlim _semaphore;
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="AsyncLock"/> class in the unlocked state.
+        /// </summary>
+        public AsyncLock()
+        {
+            _semaphore = new SemaphoreSlim(1, 1);
+        }
+
+        /// <summary>
+        /// Asynchronously waits until the lock can be taken.
+        /// </summary>
+        /// <returns>
+        /// A task whose result is a handle that releases the lock when disposed.
+        /// </returns>
+        public async Task<IDisposable> LockAsync()
+        {
+            // ConfigureAwait(false): the rest of this method does not need the caller's context.
+            await _semaphore.WaitAsync().ConfigureAwait(false);
+
+            // One handle per acquisition, so each handle can track whether it already released.
+            return new Releaser(this);
+        }
+
+        /// <summary>
+        /// Handle returned by <see cref="LockAsync"/>; releases the lock on the first
+        /// <see cref="Dispose"/> call and ignores any further calls.
+        /// </summary>
+        private sealed class Releaser : IDisposable
+        {
+            private readonly AsyncLock _lockToRelease;
+
+            // 0 = lock still held by this handle, 1 = already released.
+            private int _released;
+
+            internal Releaser(AsyncLock lockToRelease)
+            {
+                _lockToRelease = lockToRelease;
+            }
+
+            public void Dispose()
+            {
+                // Releasing twice would either throw SemaphoreFullException or admit a second
+                // holder into the guarded region, so only the first call releases.
+                if (Interlocked.Exchange(ref _released, 1) == 0)
+                    _lockToRelease._semaphore.Release();
+            }
+        }
+    }
+}
diff --git a/LLama.Web/Common/LLamaOptions.cs b/LLama.Web/Common/LLamaOptions.cs
index 1ac0d829..a64b9635 100644
--- a/LLama.Web/Common/LLamaOptions.cs
+++ b/LLama.Web/Common/LLamaOptions.cs
@@ -2,6 +2,7 @@
{
public class LLamaOptions
{
+ public ModelLoadType ModelLoadType { get; set; }
public List Models { get; set; }
public List Prompts { get; set; } = new List();
public List Parameters { get; set; } = new List();
diff --git a/LLama.Web/Common/ModelLoadType.cs b/LLama.Web/Common/ModelLoadType.cs
new file mode 100644
index 00000000..9e1c77b7
--- /dev/null
+++ b/LLama.Web/Common/ModelLoadType.cs
@@ -0,0 +1,30 @@
+namespace LLama.Web.Common
+{
+ /// <summary>
+ /// The type of model load caching to use
+ /// </summary>
+ public enum ModelLoadType
+ {
+
+ /// <summary>
+ /// Only one model will be loaded into memory at a time, any other models will be unloaded before the new one is loaded
+ /// </summary>
+ Single = 0,
+
+ /// <summary>
+ /// Multiple models will be loaded into memory, ensure you use the ModelConfigs to split the hardware resources
+ /// </summary>
+ Multiple = 1,
+
+ /// <summary>
+ /// The first model in the appsettings.json list will be preloaded into memory at app startup
+ /// </summary>
+ PreloadSingle = 2,
+
+
+ /// <summary>
+ /// All models in the appsettings.json list will be preloaded into memory at app startup, ensure you use the ModelConfigs to split the hardware resources
+ /// </summary>
+ PreloadMultiple = 3,
+ }
+}
diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index f06757e3..c6cf0988 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -3,105 +3,123 @@ using LLama.Abstractions;
namespace LLama.Web.Common
{
- public class ModelOptions
- : IModelParams
+ public class ModelOptions : IModelParams
{
-
+ ///
+ /// Model friendly name
+ ///
public string Name { get; set; }
+
+ ///
+ /// Max context instances allowed per model
+ ///
public int MaxInstances { get; set; }
+ ///
+ /// Model context size (n_ctx)
+ ///
+ public int ContextSize { get; set; } = 512;
+
+ ///
+ /// the GPU that is used for scratch and small tensors
+ ///
+ public int MainGpu { get; set; } = 0;
+
+ ///
+ /// if true, reduce VRAM usage at the cost of performance
+ ///
+ public bool LowVram { get; set; } = false;
+
+ ///
+ /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
+ ///
+ public int GpuLayerCount { get; set; } = 20;
+
+ ///
+ /// Seed for the random number generator (seed)
+ ///
+ public int Seed { get; set; } = 1686349486;
+
+ ///
+ /// Use f16 instead of f32 for memory kv (memory_f16)
+ ///
+ public bool UseFp16Memory { get; set; } = true;
+
+ ///
+ /// Use mmap for faster loads (use_mmap)
+ ///
+ public bool UseMemorymap { get; set; } = true;
+
+ ///
+ /// Use mlock to keep model in memory (use_mlock)
+ ///
+ public bool UseMemoryLock { get; set; } = false;
+
+ ///
+ /// Compute perplexity over the prompt (perplexity)
+ ///
+ public bool Perplexity { get; set; } = false;
+
+ ///
+ /// Model path (model)
+ ///
+ public string ModelPath { get; set; }
+
+ ///
+ /// model alias
+ ///
+ public string ModelAlias { get; set; } = "unknown";
- ///
- /// Model context size (n_ctx)
- ///
- public int ContextSize { get; set; } = 512;
- ///
- /// the GPU that is used for scratch and small tensors
- ///
- public int MainGpu { get; set; } = 0;
- ///
- /// if true, reduce VRAM usage at the cost of performance
- ///
- public bool LowVram { get; set; } = false;
- ///
- /// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
- ///
- public int GpuLayerCount { get; set; } = 20;
- ///
- /// Seed for the random number generator (seed)
- ///
- public int Seed { get; set; } = 1686349486;
- ///
- /// Use f16 instead of f32 for memory kv (memory_f16)
- ///
- public bool UseFp16Memory { get; set; } = true;
- ///
- /// Use mmap for faster loads (use_mmap)
- ///
- public bool UseMemorymap { get; set; } = true;
- ///
- /// Use mlock to keep model in memory (use_mlock)
- ///
- public bool UseMemoryLock { get; set; } = false;
- ///
- /// Compute perplexity over the prompt (perplexity)
- ///
- public bool Perplexity { get; set; } = false;
- ///
- /// Model path (model)
- ///
- public string ModelPath { get; set; }
- ///
- /// model alias
- ///
- public string ModelAlias { get; set; } = "unknown";
- ///
- /// lora adapter path (lora_adapter)
- ///
- public string LoraAdapter { get; set; } = string.Empty;
- ///
- /// base model path for the lora adapter (lora_base)
- ///
- public string LoraBase { get; set; } = string.Empty;
- ///
- /// Number of threads (-1 = autodetect) (n_threads)
- ///
- public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1);
- ///
- /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
- ///
- public int BatchSize { get; set; } = 512;
-
- ///
- /// Whether to convert eos to newline during the inference.
- ///
- public bool ConvertEosToNewLine { get; set; } = false;
-
- ///
- /// Whether to use embedding mode. (embedding) Note that if this is set to true,
- /// The LLamaModel won't produce text response anymore.
- ///
- public bool EmbeddingMode { get; set; } = false;
-
- ///
- /// how split tensors should be distributed across GPUs
- ///
- public float[] TensorSplits { get; set; }
-
- ///
- /// RoPE base frequency
- ///
- public float RopeFrequencyBase { get; set; } = 10000.0f;
-
- ///
- /// RoPE frequency scaling factor
- ///
- public float RopeFrequencyScale { get; set; } = 1.0f;
-
- ///
- /// Use experimental mul_mat_q kernels
- ///
- public bool MulMatQ { get; set; }
+ ///
+ /// lora adapter path (lora_adapter)
+ ///
+ public string LoraAdapter { get; set; } = string.Empty;
+
+ ///
+ /// base model path for the lora adapter (lora_base)
+ ///
+ public string LoraBase { get; set; } = string.Empty;
+
+ ///
+ /// Number of threads (-1 = autodetect) (n_threads)
+ ///
+ public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1);
+
+ ///
+ /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch)
+ ///
+ public int BatchSize { get; set; } = 512;
+
+ ///
+ /// Whether to convert eos to newline during the inference.
+ ///
+ public bool ConvertEosToNewLine { get; set; } = false;
+
+ ///
+ /// Whether to use embedding mode. (embedding) Note that if this is set to true,
+ /// The LLamaModel won't produce text response anymore.
+ ///
+ public bool EmbeddingMode { get; set; } = false;
+
+ ///
+ /// how split tensors should be distributed across GPUs
+ ///
+ public float[] TensorSplits { get; set; }
+
+ ///
+ /// RoPE base frequency
+ ///
+ public float RopeFrequencyBase { get; set; } = 10000.0f;
+
+ ///
+ /// RoPE frequency scaling factor
+ ///
+ public float RopeFrequencyScale { get; set; } = 1.0f;
+
+ ///
+ /// Use experimental mul_mat_q kernels
+ ///
+ public bool MulMatQ { get; set; }
///
/// The encoding to use for models
diff --git a/LLama.Web/LLamaModel.cs b/LLama.Web/LLamaModel.cs
new file mode 100644
index 00000000..e500ba04
--- /dev/null
+++ b/LLama.Web/LLamaModel.cs
@@ -0,0 +1,114 @@
+using LLama.Abstractions;
+using LLama.Web.Common;
+using System.Collections.Concurrent;
+
+namespace LLama.Web
+{
+    /// <summary>
+    /// Wrapper class for LLamaSharp LLamaWeights that also owns the named contexts created from them
+    /// </summary>
+    /// <seealso cref="System.IDisposable" />
+    public class LLamaModel : IDisposable
+    {
+        private readonly ModelOptions _config;
+        private readonly LLamaWeights _weights;
+        private readonly ConcurrentDictionary<string, LLamaContext> _contexts;
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="LLamaModel"/> class and loads its weights.
+        /// </summary>
+        /// <param name="modelParams">The model parameters.</param>
+        public LLamaModel(ModelOptions modelParams)
+        {
+            _config = modelParams;
+            _weights = LLamaWeights.LoadFromFile(modelParams);
+            _contexts = new ConcurrentDictionary<string, LLamaContext>();
+        }
+
+        /// <summary>
+        /// Gets the model configuration.
+        /// </summary>
+        public IModelParams ModelParams => _config;
+
+        /// <summary>
+        /// Gets the LLamaWeights
+        /// </summary>
+        public LLamaWeights LLamaWeights => _weights;
+
+        /// <summary>
+        /// Gets the number of contexts currently registered on this model.
+        /// </summary>
+        public int ContextCount => _contexts.Count;
+
+        /// <summary>
+        /// Creates a new context session on this model
+        /// </summary>
+        /// <param name="contextName">The unique context identifier</param>
+        /// <returns>The newly created LLamaContext</returns>
+        /// <exception cref="InvalidOperationException">
+        /// A context with this name already exists, or the MaxInstances limit has been reached
+        /// </exception>
+        public Task<LLamaContext> CreateContext(string contextName)
+        {
+            if (_contexts.ContainsKey(contextName))
+                throw new InvalidOperationException($"Context with id {contextName} already exists.");
+
+            // A negative MaxInstances disables the limit.
+            if (_config.MaxInstances > -1 && ContextCount >= _config.MaxInstances)
+                throw new InvalidOperationException("Maximum model instances reached");
+
+            var context = _weights.CreateContext(_config);
+            if (!_contexts.TryAdd(contextName, context))
+            {
+                // Another caller registered the same name after the check above; dispose the
+                // context created here rather than dropping it undisposed.
+                context.Dispose();
+                throw new InvalidOperationException($"Context with id {contextName} already exists.");
+            }
+
+            return Task.FromResult(context);
+        }
+
+        /// <summary>
+        /// Gets a context belonging to this model
+        /// </summary>
+        /// <param name="contextName">The unique context identifier</param>
+        /// <returns>The context registered under the specified contextName, or null if there is none</returns>
+        public Task<LLamaContext> GetContext(string contextName)
+        {
+            // On a miss TryGetValue leaves the out variable null, which is the "not found" result.
+            _contexts.TryGetValue(contextName, out var context);
+            return Task.FromResult(context);
+        }
+
+        /// <summary>
+        /// Removes a context from this model and disposes it
+        /// </summary>
+        /// <param name="contextName">The unique context identifier</param>
+        /// <returns>true if removed, otherwise false</returns>
+        public Task<bool> RemoveContext(string contextName)
+        {
+            if (!_contexts.TryRemove(contextName, out var context))
+                return Task.FromResult(false);
+
+            context?.Dispose();
+            return Task.FromResult(true);
+        }
+
+        /// <summary>
+        /// Disposes every registered context and then the underlying weights.
+        /// </summary>
+        public void Dispose()
+        {
+            // Take each context out of the dictionary before disposing it so that no
+            // context is disposed twice if Dispose is called again.
+            foreach (var contextName in _contexts.Keys)
+            {
+                if (_contexts.TryRemove(contextName, out var context))
+                    context?.Dispose();
+            }
+
+            _weights.Dispose();
+        }
+    }
+}
diff --git a/LLama.Web/Services/IModelService.cs b/LLama.Web/Services/IModelService.cs
new file mode 100644
index 00000000..0a98f8f4
--- /dev/null
+++ b/LLama.Web/Services/IModelService.cs
@@ -0,0 +1,68 @@
+using LLama.Web.Common;
+
+namespace LLama.Web.Services
+{
+    /// <summary>
+    /// Service for managing language Models
+    /// </summary>
+    public interface IModelService
+    {
+        /// <summary>
+        /// Gets the model with the specified name.
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        Task<LLamaModel> GetModel(string modelName);
+
+        /// <summary>
+        /// Loads a model from a ModelOptions object.
+        /// </summary>
+        /// <param name="modelOptions">The model configuration.</param>
+        Task<LLamaModel> LoadModel(ModelOptions modelOptions);
+
+        /// <summary>
+        /// Preloads the configured models, as determined by the configured ModelLoadType.
+        /// </summary>
+        Task LoadModels();
+
+        /// <summary>
+        /// Unloads the model with the specified name.
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        Task<bool> UnloadModel(string modelName);
+
+        /// <summary>
+        /// Unloads all models.
+        /// </summary>
+        Task UnloadModels();
+
+        /// <summary>
+        /// Gets a context with the specified identifier
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        /// <param name="contextName">The context identifier.</param>
+        Task<LLamaContext> GetContext(string modelName, string contextName);
+
+        /// <summary>
+        /// Removes the context.
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        /// <param name="contextName">The context identifier.</param>
+        Task<bool> RemoveContext(string modelName, string contextName);
+
+        /// <summary>
+        /// Creates a context.
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        /// <param name="contextName">The context identifier.</param>
+        Task<LLamaContext> CreateContext(string modelName, string contextName);
+
+        /// <summary>
+        /// Gets the model and context, creating them if necessary.
+        /// This will load a model from disk if not already loaded, and also create the context
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        /// <param name="contextName">The context identifier.</param>
+        /// <returns>Both loaded Model and Context</returns>
+        Task<(LLamaModel, LLamaContext)> GetOrCreateModelAndContext(string modelName, string contextName);
+    }
+}
\ No newline at end of file
diff --git a/LLama.Web/Services/ModelService.cs b/LLama.Web/Services/ModelService.cs
new file mode 100644
index 00000000..16365a5d
--- /dev/null
+++ b/LLama.Web/Services/ModelService.cs
@@ -0,0 +1,197 @@
+using LLama.Web.Async;
+using LLama.Web.Common;
+using System.Collections.Concurrent;
+
+namespace LLama.Web.Services
+{
+    /// <summary>
+    /// Service for handling Models, Weights and Contexts
+    /// </summary>
+    public class ModelService : IModelService
+    {
+        private readonly AsyncLock _modelLock;
+        private readonly AsyncLock _contextLock;
+        private readonly LLamaOptions _configuration;
+        private readonly ConcurrentDictionary<string, LLamaModel> _modelInstances;
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="ModelService"/> class.
+        /// </summary>
+        /// <param name="configuration">The options.</param>
+        public ModelService(LLamaOptions configuration)
+        {
+            _modelLock = new AsyncLock();
+            _contextLock = new AsyncLock();
+            _configuration = configuration;
+            _modelInstances = new ConcurrentDictionary<string, LLamaModel>();
+        }
+
+        /// <summary>
+        /// Loads a model with the provided configuration.
+        /// </summary>
+        /// <param name="modelOptions">The model configuration.</param>
+        /// <returns>The already loaded model with that name, or the newly loaded one</returns>
+        public async Task<LLamaModel> LoadModel(ModelOptions modelOptions)
+        {
+            // Fast path: the model is already loaded, no lock required.
+            if (_modelInstances.TryGetValue(modelOptions.Name, out var existingModel))
+                return existingModel;
+
+            using (await _modelLock.LockAsync())
+            {
+                // Check again under the lock: another caller may have loaded it while we waited.
+                if (_modelInstances.TryGetValue(modelOptions.Name, out var model))
+                    return model;
+
+                // If in single mode unload any other models
+                if (_configuration.ModelLoadType == ModelLoadType.Single
+                    || _configuration.ModelLoadType == ModelLoadType.PreloadSingle)
+                    await UnloadModels();
+
+                model = new LLamaModel(modelOptions);
+                _modelInstances.TryAdd(modelOptions.Name, model);
+                return model;
+            }
+        }
+
+        /// <summary>
+        /// Preloads the configured models. Does nothing unless the ModelLoadType is
+        /// PreloadSingle (first configured model only) or PreloadMultiple (all of them).
+        /// </summary>
+        public async Task LoadModels()
+        {
+            if (_configuration.ModelLoadType == ModelLoadType.Single
+                || _configuration.ModelLoadType == ModelLoadType.Multiple)
+                return;
+
+            // No models configured, nothing to preload.
+            if (_configuration.Models is null)
+                return;
+
+            foreach (var modelConfig in _configuration.Models)
+            {
+                await LoadModel(modelConfig);
+
+                // Only preload the first model if in PreloadSingle mode
+                if (_configuration.ModelLoadType == ModelLoadType.PreloadSingle)
+                    break;
+            }
+        }
+
+        /// <summary>
+        /// Unloads the model and disposes it.
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        /// <returns>true if a model with that name was loaded, otherwise false</returns>
+        public Task<bool> UnloadModel(string modelName)
+        {
+            if (_modelInstances.TryRemove(modelName, out var model))
+            {
+                model?.Dispose();
+                return Task.FromResult(true);
+            }
+            return Task.FromResult(false);
+        }
+
+        /// <summary>
+        /// Unloads all models.
+        /// </summary>
+        public async Task UnloadModels()
+        {
+            foreach (var modelName in _modelInstances.Keys)
+            {
+                await UnloadModel(modelName);
+            }
+        }
+
+        /// <summary>
+        /// Gets a model by name.
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        /// <returns>The loaded model, or null if no model with that name is loaded</returns>
+        public Task<LLamaModel> GetModel(string modelName)
+        {
+            _modelInstances.TryGetValue(modelName, out var model);
+            return Task.FromResult(model);
+        }
+
+        /// <summary>
+        /// Gets a context from the specified model.
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        /// <param name="contextName">The contextName.</param>
+        /// <returns>The context, or null if the model has no context with that name</returns>
+        /// <exception cref="InvalidOperationException">Model not found</exception>
+        public async Task<LLamaContext> GetContext(string modelName, string contextName)
+        {
+            if (!_modelInstances.TryGetValue(modelName, out var model))
+                throw new InvalidOperationException("Model not found");
+
+            return await model.GetContext(contextName);
+        }
+
+        /// <summary>
+        /// Creates a context on the specified model.
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        /// <param name="contextName">The contextName.</param>
+        /// <returns>The created context</returns>
+        /// <exception cref="InvalidOperationException">Model not found</exception>
+        public async Task<LLamaContext> CreateContext(string modelName, string contextName)
+        {
+            if (!_modelInstances.TryGetValue(modelName, out var model))
+                throw new InvalidOperationException("Model not found");
+
+            using (await _contextLock.LockAsync())
+            {
+                return await model.CreateContext(contextName);
+            }
+        }
+
+        /// <summary>
+        /// Removes a context from the specified model.
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        /// <param name="contextName">The contextName.</param>
+        /// <returns>true if the context was removed, otherwise false</returns>
+        /// <exception cref="InvalidOperationException">Model not found</exception>
+        public async Task<bool> RemoveContext(string modelName, string contextName)
+        {
+            if (!_modelInstances.TryGetValue(modelName, out var model))
+                throw new InvalidOperationException("Model not found");
+
+            using (await _contextLock.LockAsync())
+            {
+                return await model.RemoveContext(contextName);
+            }
+        }
+
+        /// <summary>
+        /// Gets the named model and context, loading the model and creating the context if needed.
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        /// <param name="contextName">The contextName.</param>
+        /// <returns>Both the loaded Model and the Context</returns>
+        /// <exception cref="InvalidOperationException">Model option '{modelName}' not found</exception>
+        public async Task<(LLamaModel, LLamaContext)> GetOrCreateModelAndContext(string modelName, string contextName)
+        {
+            if (!_modelInstances.TryGetValue(modelName, out var model))
+            {
+                // Not loaded yet: look up its configuration and load it.
+                var modelConfig = _configuration.Models?.FirstOrDefault(x => x.Name == modelName);
+                if (modelConfig is null)
+                    throw new InvalidOperationException($"Model option '{modelName}' not found");
+
+                model = await LoadModel(modelConfig);
+            }
+
+            // Look up and create under the same lock CreateContext/RemoveContext use, so two
+            // callers asking for the same new context cannot both try to create it.
+            using (await _contextLock.LockAsync())
+            {
+                var context = await model.GetContext(contextName) ?? await model.CreateContext(contextName);
+                return (model, context);
+            }
+        }
+    }
+}