Browse Source

Added a higher level `LLamaWeights` wrapper around `SafeLlamaModelHandle`

tags/v0.5.1
Martin Evans 2 years ago
parent
commit
e2fe08a9a2
3 changed files with 79 additions and 17 deletions
  1. +11
    -17
      LLama.Examples/NewVersion/TalkToYourself.cs
  2. +11
    -0
      LLama/LLamaContext.cs
  3. +57
    -0
      LLama/LLamaWeights.cs

+ 11
- 17
LLama.Examples/NewVersion/TalkToYourself.cs View File

@@ -2,8 +2,6 @@
using System.Text;
using LLama.Abstractions;
using LLama.Common;
using LLama.Extensions;
using LLama.Native;

namespace LLama.Examples.NewVersion
{
@@ -12,26 +10,20 @@ namespace LLama.Examples.NewVersion
public static async Task Run()
{
Console.Write("Please input your model path: ");
string modelPath = "C:\\Users\\Martin\\Documents\\Python\\oobabooga_windows\\text-generation-webui\\models\\llama-2-7b-chat.ggmlv3.q6_K.bin";
var modelPath = Console.ReadLine();

// todo: model path is passed here, but isn't needed
// Load weights into memory
var @params = new ModelParams(modelPath)
{
Seed = RandomNumberGenerator.GetInt32(int.MaxValue)
};
using var weights = LLamaWeights.LoadFromFile(@params);

// todo: all this pin stuff is ugly and should be hidden in the higher level wrapper
using var pin = @params.ToLlamaContextParams(out var lparams);

// todo: we need a higher level wrapper around the model weights (LLamaWeights??)
var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);

// todo: need a method on the LLamaWeights which does this
var ctx1 = new LLamaContext(weights.CreateContext(lparams), @params, Encoding.UTF8);
var ctx2 = new LLamaContext(weights.CreateContext(lparams), @params, Encoding.UTF8);

var alice = new InteractiveExecutor(ctx1);
var bob = new InteractiveExecutor(ctx2);
// Create 2 contexts sharing the same weights
using var aliceCtx = weights.CreateContext(@params, Encoding.UTF8);
var alice = new InteractiveExecutor(aliceCtx);
using var bobCtx = weights.CreateContext(@params, Encoding.UTF8);
var bob = new InteractiveExecutor(bobCtx);

// Initial alice prompt
var alicePrompt = "Transcript of a dialog, where the Alice interacts a person named Bob. Alice is friendly, kind, honest and good at writing.\nAlice: Hello";
@@ -46,7 +38,9 @@ namespace LLama.Examples.NewVersion
{
aliceResponse = await Prompt(alice, ConsoleColor.Green, bobResponse, false, true);
bobResponse = await Prompt(bob, ConsoleColor.Red, aliceResponse, false, true);
Thread.Sleep(1000);

if (Console.KeyAvailable)
break;
}
}



+ 11
- 0
LLama/LLamaContext.cs View File

@@ -84,6 +84,17 @@ namespace LLama
_ctx = nativeContext;
}

public LLamaContext(LLamaWeights model, IModelParams @params, Encoding encoding, ILLamaLogger? logger = null)
{
Params = @params;

_logger = logger;
_encoding = encoding;

using var pin = @params.ToLlamaContextParams(out var lparams);
_ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
}

/// <summary>
/// Create a copy of the current state of this context
/// </summary>


+ 57
- 0
LLama/LLamaWeights.cs View File

@@ -0,0 +1,57 @@
using System;
using System.Text;
using LLama.Common;
using LLama.Extensions;
using LLama.Native;

namespace LLama
{
/// <summary>
/// A set of model weights, loaded into memory.
/// </summary>
public class LLamaWeights
: IDisposable
{
private readonly SafeLlamaModelHandle _weights;

/// <summary>
/// The native handle, which is used in the native APIs
/// </summary>
/// <remarks>Be careful how you use this!</remarks>
public SafeLlamaModelHandle NativeHandle => _weights;

private LLamaWeights(SafeLlamaModelHandle weights)
{
_weights = weights;
}

/// <summary>
/// Load weights into memory
/// </summary>
/// <param name="params"></param>
/// <returns></returns>
public static LLamaWeights LoadFromFile(ModelParams @params)
{
using var pin = @params.ToLlamaContextParams(out var lparams);
var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
return new LLamaWeights(weights);
}

/// <inheritdoc />
public void Dispose()
{
_weights.Dispose();
}

/// <summary>
/// Create a llama_context using this model
/// </summary>
/// <param name="params"></param>
/// <param name="utf8"></param>
/// <returns></returns>
public LLamaContext CreateContext(ModelParams @params, Encoding utf8)
{
return new LLamaContext(this, @params, Encoding.UTF8);
}
}
}

Loading…
Cancel
Save