diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index c29b6b25..f60f3cd5 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -97,15 +97,18 @@ namespace LLama private float[] GetEmbeddingsArray() { - var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); - if (embeddings == null || embeddings.Length == 0) + unsafe { - embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero); - if (embeddings == null || embeddings.Length == 0) + var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); + + if (embeddings == null) + embeddings = NativeApi.llama_get_embeddings_seq(Context.NativeHandle, LLamaSeqId.Zero); + + if (embeddings == null) return Array.Empty(); - } - return embeddings.ToArray(); + return new Span(embeddings, Context.EmbeddingSize).ToArray(); + } } private static void Normalize(Span embeddings) @@ -116,6 +119,7 @@ namespace LLama lengthSqr += value * value; var length = (float)Math.Sqrt(lengthSqr); + // Do not divide by length if it is zero if (length <= float.Epsilon) return;