diff --git a/LLama.Examples/NewVersion/TalkToYourself.cs b/LLama.Examples/NewVersion/TalkToYourself.cs index 4a61f88f..35a65241 100644 --- a/LLama.Examples/NewVersion/TalkToYourself.cs +++ b/LLama.Examples/NewVersion/TalkToYourself.cs @@ -2,8 +2,6 @@ using System.Text; using LLama.Abstractions; using LLama.Common; -using LLama.Extensions; -using LLama.Native; namespace LLama.Examples.NewVersion { @@ -12,26 +10,20 @@ namespace LLama.Examples.NewVersion public static async Task Run() { Console.Write("Please input your model path: "); - string modelPath = "C:\\Users\\Martin\\Documents\\Python\\oobabooga_windows\\text-generation-webui\\models\\llama-2-7b-chat.ggmlv3.q6_K.bin"; + var modelPath = Console.ReadLine(); - // todo: model path is passed here, but isn't needed + // Load weights into memory var @params = new ModelParams(modelPath) { Seed = RandomNumberGenerator.GetInt32(int.MaxValue) }; + using var weights = LLamaWeights.LoadFromFile(@params); - // todo: all this pin stuff is ugly and should be hidden in the higher level wrapper - using var pin = @params.ToLlamaContextParams(out var lparams); - - // todo: we need a higher level wrapper around the model weights (LLamaWeights??) - var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams); - - // todo: need a method on the LLamaWeights which does this - var ctx1 = new LLamaContext(weights.CreateContext(lparams), @params, Encoding.UTF8); - var ctx2 = new LLamaContext(weights.CreateContext(lparams), @params, Encoding.UTF8); - - var alice = new InteractiveExecutor(ctx1); - var bob = new InteractiveExecutor(ctx2); + // Create 2 contexts sharing the same weights + using var aliceCtx = weights.CreateContext(@params, Encoding.UTF8); + var alice = new InteractiveExecutor(aliceCtx); + using var bobCtx = weights.CreateContext(@params, Encoding.UTF8); + var bob = new InteractiveExecutor(bobCtx); // Initial alice prompt var alicePrompt = "Transcript of a dialog, where the Alice interacts a person named Bob. Alice is friendly, kind, honest and good at writing.\nAlice: Hello"; @@ -46,7 +38,9 @@ namespace LLama.Examples.NewVersion { aliceResponse = await Prompt(alice, ConsoleColor.Green, bobResponse, false, true); bobResponse = await Prompt(bob, ConsoleColor.Red, aliceResponse, false, true); - Thread.Sleep(1000); + + if (Console.KeyAvailable) + break; } } diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index c6b8749a..8eb2b9aa 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -84,6 +84,17 @@ namespace LLama _ctx = nativeContext; } + public LLamaContext(LLamaWeights model, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null) + { + Params = @params; + + _logger = logger; + _encoding = encoding; + + using var pin = @params.ToLlamaContextParams(out var lparams); + _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); + } + /// /// Create a copy of the current state of this context /// diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs new file mode 100644 index 00000000..be21c6f5 --- /dev/null +++ b/LLama/LLamaWeights.cs @@ -0,0 +1,57 @@ +using System; +using System.Text; +using LLama.Common; +using LLama.Extensions; +using LLama.Native; + +namespace LLama +{ + /// + /// A set of model weights, loaded into memory. + /// + public class LLamaWeights + : IDisposable + { + private readonly SafeLlamaModelHandle _weights; + + /// + /// The native handle, which is used in the native APIs + /// + /// Be careful how you use this! + public SafeLlamaModelHandle NativeHandle => _weights; + + private LLamaWeights(SafeLlamaModelHandle weights) + { + _weights = weights; + } + + /// + /// Load weights into memory + /// + /// + /// + public static LLamaWeights LoadFromFile(ModelParams @params) + { + using var pin = @params.ToLlamaContextParams(out var lparams); + var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); + return new LLamaWeights(weights); + } + + /// + public void Dispose() + { + _weights.Dispose(); + } + + /// + /// Create a llama_context using this model + /// + /// + /// + /// + public LLamaContext CreateContext(ModelParams @params, Encoding utf8) + { + return new LLamaContext(this, @params, Encoding.UTF8); + } + } +}