diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index 4bc18c1e..a45fed62 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -284,8 +284,8 @@ namespace LLama
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
- var n_vocab = NativeApi.llama_n_vocab(_ctx);
- var logits = Utils.GetLogits(_ctx, n_vocab);
+ var n_vocab = _ctx.VocabCount;
+ var logits = _ctx.GetLogits();
// Apply params.logit_bias map
if(logitBias is not null)
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 5857b590..1239da27 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -257,8 +257,8 @@ namespace LLama.Native
///
/// Token logits obtained from the last call to llama_eval()
/// The logits for the last token are stored in the last row
- /// Can be mutated in order to change the probabilities of the next token
- /// Rows: n_tokens
+ /// Can be mutated in order to change the probabilities of the next token.
+ /// Rows: n_tokens
/// Cols: n_vocab
///
///
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 9e81de69..b20462d5 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -11,11 +11,29 @@ namespace LLama.Native
public class SafeLLamaContextHandle
: SafeLLamaHandleBase
{
+ #region properties and fields
+ /// <summary>
+ /// Number of tokens in the vocabulary, as reported by the model handle held by this context
+ /// </summary>
+ /// <exception cref="ObjectDisposedException">This handle, or the model handle it holds, is no longer usable</exception>
+ public int VocabCount
+ {
+     get { return ThrowIfDisposed().VocabCount; }
+ }
+
+ /// <summary>
+ /// Number of tokens in the context, as reported by the model handle held by this context
+ /// </summary>
+ /// <exception cref="ObjectDisposedException">This handle, or the model handle it holds, is no longer usable</exception>
+ public int ContextSize
+ {
+     get { return ThrowIfDisposed().ContextSize; }
+ }
+
+ /// <summary>
+ /// Dimension of the embedding vectors, as reported by the model handle held by this context
+ /// </summary>
+ /// <exception cref="ObjectDisposedException">This handle, or the model handle it holds, is no longer usable</exception>
+ public int EmbeddingCount
+ {
+     get { return ThrowIfDisposed().EmbeddingCount; }
+ }
+
/// <summary>
/// This field guarantees that a reference to the model is held for as long as this handle is held
/// </summary>
private SafeLlamaModelHandle? _model;
+ #endregion
+ #region construction/destruction
///
/// Create a new SafeLLamaContextHandle
///
@@ -44,6 +62,16 @@ namespace LLama.Native
return true;
}
+ /// <summary>
+ /// Check that this context handle and the model handle it holds are both still open
+ /// </summary>
+ /// <returns>The model handle held by this context</returns>
+ /// <exception cref="ObjectDisposedException">This handle is closed, or the held model handle is null or closed</exception>
+ private SafeLlamaModelHandle ThrowIfDisposed()
+ {
+     // The single-string ObjectDisposedException constructor treats its argument as the object name,
+     // so the explanatory text must be passed through the (objectName, message) overload.
+     if (IsClosed)
+         throw new ObjectDisposedException(nameof(SafeLLamaContextHandle), "Cannot use this `SafeLLamaContextHandle` - it has been disposed");
+     if (_model == null || _model.IsClosed)
+         throw new ObjectDisposedException(nameof(SafeLlamaModelHandle), "Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed");
+
+     return _model;
+ }
+
///
/// Create a new llama_state for the given model
///
@@ -59,6 +87,7 @@ namespace LLama.Native
return new(ctx_ptr, model);
}
+ #endregion
///
/// Convert the given text into tokens
@@ -70,6 +99,8 @@ namespace LLama.Native
///
public int[] Tokenize(string text, bool add_bos, Encoding encoding)
{
+ ThrowIfDisposed();
+
// Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't
// possibly be more than this.
var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);
@@ -97,5 +128,25 @@ namespace LLama.Native
ArrayPool.Shared.Return(temporaryArray);
}
}
+
+ /// <summary>
+ /// Token logits obtained from the last call to llama_eval()
+ /// The logits for the last token are stored in the last row
+ /// Can be mutated in order to change the probabilities of the next token.<br />
+ /// Rows: n_tokens<br />
+ /// Cols: n_vocab
+ /// </summary>
+ /// <remarks>
+ /// The returned span wraps the pointer returned by the native call and is exactly <c>VocabCount</c> elements long.
+ /// NOTE(review): if the native buffer holds more than one row, rows after the first are not reachable
+ /// through this span - confirm against the native llama_get_logits contract.
+ /// </remarks>
+ /// <returns>A span of <c>VocabCount</c> floats over the native logits buffer</returns>
+ /// <exception cref="ObjectDisposedException">This handle, or the model handle it holds, is no longer usable</exception>
+ public Span<float> GetLogits()
+ {
+     // Validate both handles before touching native memory
+     var model = ThrowIfDisposed();
+
+     unsafe
+     {
+         var logits = NativeApi.llama_get_logits(this);
+         return new Span<float>(logits, model.VocabCount);
+     }
+ }
}
}
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 79714fea..dbb1b070 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -13,17 +13,17 @@ namespace LLama.Native
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
- public int VocabCount { get; set; }
+ public int VocabCount { get; }
/// <summary>
/// Total number of tokens in the context
/// </summary>
- public int ContextSize { get; set; }
+ public int ContextSize { get; }
/// <summary>
/// Dimension of embedding vectors
/// </summary>
- public int EmbeddingCount { get; set; }
+ public int EmbeddingCount { get; }
internal SafeLlamaModelHandle(IntPtr handle)
: base(handle)
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
index 7a1f5f42..05964c98 100644
--- a/LLama/Utils.cs
+++ b/LLama/Utils.cs
@@ -33,10 +33,13 @@ namespace LLama
return ctx.Tokenize(text, add_bos, encoding);
}
- public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
+ /// <summary>
+ /// Obsolete forwarding wrapper around <see cref="SafeLLamaContextHandle.GetLogits"/>
+ /// </summary>
+ /// <param name="ctx">Context handle to read the logits from</param>
+ /// <param name="length">Must be equal to <c>ctx.VocabCount</c>; any other value is rejected</param>
+ /// <returns>The span returned by <c>ctx.GetLogits()</c></returns>
+ /// <exception cref="ArgumentException"><paramref name="length"/> differs from <c>ctx.VocabCount</c></exception>
+ [Obsolete("Use SafeLLamaContextHandle GetLogits method instead")]
+ public static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
{
- var logits = NativeApi.llama_get_logits(ctx);
- return new Span<float>(logits, length);
+ // The handle always returns VocabCount elements, so a caller-supplied length can only be validated, not honoured
+ if (length != ctx.VocabCount)
+     throw new ArgumentException("length must be the VocabCount", nameof(length));
+
+ return ctx.GetLogits();
}
public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)