Browse Source

Added some additional basic tests

tags/v0.5.1
Martin Evans 2 years ago
parent
commit
1b35be2e0c
4 changed files with 88 additions and 5 deletions
  1. +0
    -1
      LLama.Unittest/BasicTest.cs
  2. +36
    -0
      LLama.Unittest/LLamaContextTests.cs
  3. +44
    -0
      LLama.Unittest/LLamaEmbedderTests.cs
  4. +8
    -4
      LLama/LLamaEmbedder.cs

+ 0
- 1
LLama.Unittest/BasicTest.cs View File

@@ -1,4 +1,3 @@
using LLama;
using LLama.Common;

namespace LLama.Unittest


+ 36
- 0
LLama.Unittest/LLamaContextTests.cs View File

@@ -0,0 +1,36 @@
using System.Text;
using LLama.Common;

namespace LLama.Unittest
{
public class LLamaContextTests
: IDisposable
{
private readonly LLamaWeights _weights;
private readonly LLamaContext _context;

public LLamaContextTests()
{
var @params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
{
ContextSize = 768,
};
_weights = LLamaWeights.LoadFromFile(@params);
_context = _weights.CreateContext(@params, Encoding.UTF8);
}

public void Dispose()
{
_weights.Dispose();
_context.Dispose();
}

[Fact]
public void CheckProperties()
{
Assert.Equal(768, _context.ContextSize);
Assert.Equal(4096, _context.EmbeddingSize);
Assert.Equal(32000, _context.VocabCount);
}
}
}

+ 44
- 0
LLama.Unittest/LLamaEmbedderTests.cs View File

@@ -0,0 +1,44 @@
using LLama.Common;

namespace LLama.Unittest
{
public class LLamaEmbedderTests
: IDisposable
{
private readonly LLamaEmbedder _embedder = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin"));

public void Dispose()
{
_embedder.Dispose();
}

private static float Dot(float[] a, float[] b)
{
Assert.Equal(a.Length, b.Length);
return a.Zip(b, (x, y) => x + y).Sum();
}

[Fact]
public void EmbedHello()
{
var hello = _embedder.GetEmbeddings("Hello");

Assert.NotNull(hello);
Assert.NotEmpty(hello);
Assert.Equal(_embedder.EmbeddingSize, hello.Length);
}

[Fact]
public void EmbedCompare()
{
var cat = _embedder.GetEmbeddings("cat");
var kitten = _embedder.GetEmbeddings("kitten");
var spoon = _embedder.GetEmbeddings("spoon");

var close = Dot(cat, kitten);
var far = Dot(cat, spoon);

Assert.True(close < far);
}
}
}

+ 8
- 4
LLama/LLamaEmbedder.cs View File

@@ -1,9 +1,7 @@
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
using System.Linq;
using LLama.Abstractions;

namespace LLama
@@ -11,9 +9,15 @@ namespace LLama
/// <summary>
/// The embedder for LLama, which supports getting embeddings from text.
/// </summary>
public class LLamaEmbedder : IDisposable
public class LLamaEmbedder
: IDisposable
{
SafeLLamaContextHandle _ctx;
private readonly SafeLLamaContextHandle _ctx;

/// <summary>
/// Dimension of embedding vectors
/// </summary>
public int EmbeddingSize => _ctx.EmbeddingSize;

/// <summary>
/// Warning: must ensure the original model has params.embedding = true;


Loading…
Cancel
Save