diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 8eb2b9aa..f0817f8f 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -74,8 +74,7 @@ namespace LLama _ctx = Utils.InitLLamaContextFromModelParams(Params); } - //todo make this internal - public LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) + internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) { Params = @params; @@ -86,6 +85,9 @@ namespace LLama public LLamaContext(LLamaWeights model, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) { + if (model.NativeHandle.IsClosed) + throw new ObjectDisposedException("Cannot create context, model weights have been disposed"); + Params = @params; _logger = logger; diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index be21c6f5..8226753f 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -1,6 +1,6 @@ using System; using System.Text; -using LLama.Common; +using LLama.Abstractions; using LLama.Extensions; using LLama.Native; @@ -30,10 +30,14 @@ namespace LLama /// /// /// - public static LLamaWeights LoadFromFile(ModelParams @params) + public static LLamaWeights LoadFromFile(IModelParams @params) { using var pin = @params.ToLlamaContextParams(out var lparams); var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); + + if (!string.IsNullOrEmpty(@params.LoraAdapter)) + weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); + return new LLamaWeights(weights); } @@ -47,11 +51,11 @@ namespace LLama /// Create a llama_context using this model /// /// - /// + /// /// - public LLamaContext CreateContext(ModelParams @params, Encoding utf8) + public LLamaContext CreateContext(IModelParams @params, Encoding encoding) { - return new LLamaContext(this, @params, Encoding.UTF8); + return new LLamaContext(this, @params, encoding); } } } diff --git a/LLama/Utils.cs b/LLama/Utils.cs index 42172737..9f4fb3fa 100644 --- a/LLama/Utils.cs +++ b/LLama/Utils.cs @@ -13,16 +13,10 @@ namespace LLama { public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params) { - using (@params.ToLlamaContextParams(out var lparams)) - { - var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); - var ctx = SafeLLamaContextHandle.Create(model, lparams); - - if (!string.IsNullOrEmpty(@params.LoraAdapter)) - model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); + using var weights = LLamaWeights.LoadFromFile(@params); - return ctx; - } + using (@params.ToLlamaContextParams(out var lparams)) + return SafeLLamaContextHandle.Create(weights.NativeHandle, lparams); } [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]