diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs index 4bc18c1e..a45fed62 100644 --- a/LLama/LLamaModel.cs +++ b/LLama/LLamaModel.cs @@ -284,8 +284,8 @@ namespace LLama int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f, bool penalizeNL = true) { - var n_vocab = NativeApi.llama_n_vocab(_ctx); - var logits = Utils.GetLogits(_ctx, n_vocab); + var n_vocab = _ctx.VocabCount; + var logits = _ctx.GetLogits(); // Apply params.logit_bias map if(logitBias is not null) diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 5857b590..1239da27 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -257,8 +257,8 @@ namespace LLama.Native /// /// Token logits obtained from the last call to llama_eval() /// The logits for the last token are stored in the last row - /// Can be mutated in order to change the probabilities of the next token - /// Rows: n_tokens + /// Can be mutated in order to change the probabilities of the next token.
+ /// Rows: n_tokens
/// Cols: n_vocab ///
/// diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 9e81de69..b20462d5 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -11,11 +11,29 @@ namespace LLama.Native public class SafeLLamaContextHandle : SafeLLamaHandleBase { + #region properties and fields + /// + /// Total number of tokens in vocabulary of this model + /// + public int VocabCount => ThrowIfDisposed().VocabCount; + + /// + /// Total number of tokens in the context + /// + public int ContextSize => ThrowIfDisposed().ContextSize; + + /// + /// Dimension of embedding vectors + /// + public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount; + /// /// This field guarantees that a reference to the model is held for as long as this handle is held /// private SafeLlamaModelHandle? _model; + #endregion + #region construction/destruction /// /// Create a new SafeLLamaContextHandle /// @@ -44,6 +62,16 @@ namespace LLama.Native return true; } + private SafeLlamaModelHandle ThrowIfDisposed() + { + if (IsClosed) + throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed"); + if (_model == null || _model.IsClosed) + throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed"); + + return _model; + } + /// /// Create a new llama_state for the given model /// @@ -59,6 +87,7 @@ namespace LLama.Native return new(ctx_ptr, model); } + #endregion /// /// Convert the given text into tokens @@ -70,6 +99,8 @@ namespace LLama.Native /// public int[] Tokenize(string text, bool add_bos, Encoding encoding) { + ThrowIfDisposed(); + // Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't // possibly be more than this. var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0); @@ -97,5 +128,25 @@ namespace LLama.Native ArrayPool.Shared.Return(temporaryArray); } } + + /// + /// Token logits obtained from the last call to llama_eval() + /// The logits for the last token are stored in the last row + /// Can be mutated in order to change the probabilities of the next token.
+ /// Rows: n_tokens
+ /// Cols: n_vocab + ///
+ /// + /// + public Span GetLogits() + { + var model = ThrowIfDisposed(); + + unsafe + { + var logits = NativeApi.llama_get_logits(this); + return new Span(logits, model.VocabCount); + } + } } } diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 79714fea..dbb1b070 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -13,17 +13,17 @@ namespace LLama.Native /// /// Total number of tokens in vocabulary of this model /// - public int VocabCount { get; set; } + public int VocabCount { get; } /// /// Total number of tokens in the context /// - public int ContextSize { get; set; } + public int ContextSize { get; } /// /// Dimension of embedding vectors /// - public int EmbeddingCount { get; set; } + public int EmbeddingCount { get; } internal SafeLlamaModelHandle(IntPtr handle) : base(handle) diff --git a/LLama/Utils.cs b/LLama/Utils.cs index 7a1f5f42..05964c98 100644 --- a/LLama/Utils.cs +++ b/LLama/Utils.cs @@ -33,10 +33,13 @@ namespace LLama return ctx.Tokenize(text, add_bos, encoding); } - public static unsafe Span GetLogits(SafeLLamaContextHandle ctx, int length) + [Obsolete("Use SafeLLamaContextHandle GetLogits method instead")] + public static Span GetLogits(SafeLLamaContextHandle ctx, int length) { - var logits = NativeApi.llama_get_logits(ctx); - return new Span(logits, length); + if (length != ctx.VocabCount) + throw new ArgumentException("length must be the VocabSize"); + + return ctx.GetLogits(); } public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)