|
|
|
@@ -1,6 +1,5 @@ |
|
|
|
using LLama.Native; |
|
|
|
using System; |
|
|
|
using System.Text; |
|
|
|
using LLama.Exceptions; |
|
|
|
using LLama.Abstractions; |
|
|
|
|
|
|
|
@@ -12,22 +11,13 @@ namespace LLama |
|
|
|
public class LLamaEmbedder |
|
|
|
: IDisposable |
|
|
|
{ |
|
|
|
private readonly SafeLLamaContextHandle _ctx; |
|
|
|
private readonly LLamaContext _ctx; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Dimension of embedding vectors |
|
|
|
/// </summary> |
|
|
|
public int EmbeddingSize => _ctx.EmbeddingSize; |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Warning: must ensure the original model has params.embedding = true; |
|
|
|
/// </summary> |
|
|
|
/// <param name="ctx"></param> |
|
|
|
internal LLamaEmbedder(SafeLLamaContextHandle ctx) |
|
|
|
{ |
|
|
|
_ctx = ctx; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// |
|
|
|
/// </summary> |
|
|
|
@@ -35,52 +25,66 @@ namespace LLama |
|
|
|
public LLamaEmbedder(IModelParams @params) |
|
|
|
{ |
|
|
|
@params.EmbeddingMode = true; |
|
|
|
_ctx = Utils.InitLLamaContextFromModelParams(@params); |
|
|
|
using var weights = LLamaWeights.LoadFromFile(@params); |
|
|
|
_ctx = weights.CreateContext(@params); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Get the embeddings of the text. |
|
|
|
/// </summary> |
|
|
|
/// <param name="text"></param> |
|
|
|
/// <param name="threads">Threads used for inference.</param> |
|
|
|
/// <param name="threads">unused</param> |
|
|
|
/// <param name="addBos">Add bos to the text.</param> |
|
|
|
/// <param name="encoding"></param> |
|
|
|
/// <param name="encoding">unused</param> |
|
|
|
/// <returns></returns> |
|
|
|
/// <exception cref="RuntimeError"></exception> |
|
|
|
public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") |
|
|
|
[Obsolete("'threads' and 'encoding' parameters are no longer used")] |
|
|
|
// ReSharper disable once MethodOverloadWithOptionalParameter |
|
|
|
public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") |
|
|
|
{ |
|
|
|
if (threads == -1) |
|
|
|
{ |
|
|
|
threads = Math.Max(Environment.ProcessorCount / 2, 1); |
|
|
|
} |
|
|
|
return GetEmbeddings(text, addBos); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Get the embeddings of the text. |
|
|
|
/// </summary> |
|
|
|
/// <param name="text"></param> |
|
|
|
/// <returns></returns> |
|
|
|
/// <exception cref="RuntimeError"></exception> |
|
|
|
public float[] GetEmbeddings(string text) |
|
|
|
{ |
|
|
|
return GetEmbeddings(text, true); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Get the embeddings of the text. |
|
|
|
/// </summary> |
|
|
|
/// <param name="text"></param> |
|
|
|
/// <param name="addBos">Add bos to the text.</param> |
|
|
|
/// <returns></returns> |
|
|
|
/// <exception cref="RuntimeError"></exception> |
|
|
|
public float[] GetEmbeddings(string text, bool addBos) |
|
|
|
{ |
|
|
|
if (addBos) |
|
|
|
{ |
|
|
|
text = text.Insert(0, " "); |
|
|
|
} |
|
|
|
|
|
|
|
var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding)); |
|
|
|
var embed_inp_array = _ctx.Tokenize(text, addBos); |
|
|
|
|
|
|
|
// TODO(Rinne): deal with log of prompt |
|
|
|
|
|
|
|
if (embed_inp_array.Length > 0) |
|
|
|
{ |
|
|
|
if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, 0, threads) != 0) |
|
|
|
{ |
|
|
|
throw new RuntimeError("Failed to eval."); |
|
|
|
} |
|
|
|
} |
|
|
|
_ctx.Eval(embed_inp_array, 0); |
|
|
|
|
|
|
|
int n_embed = NativeApi.llama_n_embd(_ctx); |
|
|
|
var embeddings = NativeApi.llama_get_embeddings(_ctx); |
|
|
|
if (embeddings == null) |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
return Array.Empty<float>(); |
|
|
|
var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle); |
|
|
|
if (embeddings == null) |
|
|
|
return Array.Empty<float>(); |
|
|
|
|
|
|
|
return new Span<float>(embeddings, EmbeddingSize).ToArray(); |
|
|
|
} |
|
|
|
var span = new Span<float>(embeddings, n_embed); |
|
|
|
float[] res = new float[n_embed]; |
|
|
|
span.CopyTo(res.AsSpan()); |
|
|
|
return res; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
|