Browse Source

Merge pull request #97 from martindevans/embedder_tests

Embedder Test
tags/v0.5.1
Martin Evans GitHub 2 years ago
parent
commit
df80ec9161
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 108 additions and 34 deletions
  1. +70
    -0
      LLama.Unittest/LLamaEmbedderTests.cs
  2. +38
    -34
      LLama/LLamaEmbedder.cs

+ 70
- 0
LLama.Unittest/LLamaEmbedderTests.cs View File

@@ -0,0 +1,70 @@
using LLama.Common;

namespace LLama.Unittest;

public class LLamaEmbedderTests
: IDisposable
{
private readonly LLamaEmbedder _embedder = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin"));

public void Dispose()
{
_embedder.Dispose();
}

private static float Magnitude(float[] a)
{
return MathF.Sqrt(a.Zip(a, (x, y) => x * y).Sum());
}

private static void Normalize(float[] a)
{
var mag = Magnitude(a);
for (var i = 0; i < a.Length; i++)
a[i] /= mag;
}

private static float Dot(float[] a, float[] b)
{
Assert.Equal(a.Length, b.Length);
return a.Zip(b, (x, y) => x * y).Sum();
}

private static void AssertApproxStartsWith(float[] expected, float[] actual, float epsilon = 0.08f)
{
for (int i = 0; i < expected.Length; i++)
Assert.Equal(expected[i], actual[i], epsilon);
}

[Fact]
public void EmbedBasic()
{
var cat = _embedder.GetEmbeddings("cat");

Assert.NotNull(cat);
Assert.NotEmpty(cat);

// Expected value generate with llama.cpp embedding.exe
var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f };
AssertApproxStartsWith(expected, cat);
}

[Fact]
public void EmbedCompare()
{
var cat = _embedder.GetEmbeddings("cat");
var kitten = _embedder.GetEmbeddings("kitten");
var spoon = _embedder.GetEmbeddings("spoon");

Normalize(cat);
Normalize(kitten);
Normalize(spoon);

var close = Dot(cat, kitten);
var far = Dot(cat, spoon);

// This comparison seems backwards, but remember that with a
// dot product 1.0 means **identical** and 0.0 means **completely opposite**!
Assert.True(close > far);
}
}

+ 38
- 34
LLama/LLamaEmbedder.cs View File

@@ -1,6 +1,5 @@
using LLama.Native;
using System;
using System.Text;
using LLama.Exceptions;
using LLama.Abstractions;

@@ -12,22 +11,13 @@ namespace LLama
public class LLamaEmbedder
: IDisposable
{
private readonly SafeLLamaContextHandle _ctx;
private readonly LLamaContext _ctx;

/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingSize => _ctx.EmbeddingSize;

/// <summary>
/// Warning: must ensure the original model has params.embedding = true;
/// </summary>
/// <param name="ctx"></param>
internal LLamaEmbedder(SafeLLamaContextHandle ctx)
{
_ctx = ctx;
}

/// <summary>
///
/// </summary>
@@ -35,52 +25,66 @@ namespace LLama
public LLamaEmbedder(IModelParams @params)
{
@params.EmbeddingMode = true;
_ctx = Utils.InitLLamaContextFromModelParams(@params);
using var weights = LLamaWeights.LoadFromFile(@params);
_ctx = weights.CreateContext(@params);
}

/// <summary>
/// Get the embeddings of the text.
/// </summary>
/// <param name="text"></param>
/// <param name="threads">Threads used for inference.</param>
/// <param name="threads">unused</param>
/// <param name="addBos">Add bos to the text.</param>
/// <param name="encoding"></param>
/// <param name="encoding">unused</param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
[Obsolete("'threads' and 'encoding' parameters are no longer used")]
// ReSharper disable once MethodOverloadWithOptionalParameter
public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
{
if (threads == -1)
{
threads = Math.Max(Environment.ProcessorCount / 2, 1);
}
return GetEmbeddings(text, addBos);
}

/// <summary>
/// Get the embeddings of the text.
/// </summary>
/// <param name="text"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public float[] GetEmbeddings(string text)
{
return GetEmbeddings(text, true);
}

/// <summary>
/// Get the embeddings of the text.
/// </summary>
/// <param name="text"></param>
/// <param name="addBos">Add bos to the text.</param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public float[] GetEmbeddings(string text, bool addBos)
{
if (addBos)
{
text = text.Insert(0, " ");
}

var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding));
var embed_inp_array = _ctx.Tokenize(text, addBos);

// TODO(Rinne): deal with log of prompt

if (embed_inp_array.Length > 0)
{
if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, 0, threads) != 0)
{
throw new RuntimeError("Failed to eval.");
}
}
_ctx.Eval(embed_inp_array, 0);

int n_embed = NativeApi.llama_n_embd(_ctx);
var embeddings = NativeApi.llama_get_embeddings(_ctx);
if (embeddings == null)
unsafe
{
return Array.Empty<float>();
var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle);
if (embeddings == null)
return Array.Empty<float>();

return new Span<float>(embeddings, EmbeddingSize).ToArray();
}
var span = new Span<float>(embeddings, n_embed);
float[] res = new float[n_embed];
span.CopyTo(res.AsSpan());
return res;
}

/// <summary>


Loading…
Cancel
Save