
Assorted cleanup left over after the huge change in the last PR (comments, syntax style, etc.)

commit 9daf586ba8 (tags/v0.6.0)
Martin Evans, 2 years ago
22 changed files with 257 additions and 201 deletions:

1. LLama/Abstractions/IModelParams.cs (+0 -5)
2. LLama/Common/FixedSizeQueue.cs (+5 -6)
3. LLama/Common/ModelParams.cs (+3 -3)
4. LLama/Exceptions/GrammarFormatExceptions.cs (+1 -1)
5. LLama/Exceptions/RuntimeError.cs (+14 -13)
6. LLama/Extensions/EncodingExtensions.cs (+0 -1)
7. LLama/Extensions/KeyValuePairExtensions.cs (+17 -20)
8. LLama/Grammars/GBNFGrammarParser.cs (+46 -46)
9. LLama/Grammars/Grammar.cs (+0 -1)
10. LLama/LLamaContext.cs (+37 -39)
11. LLama/LLamaEmbedder.cs (+16 -0)
12. LLama/LLamaExecutorBase.cs (+1 -1)
13. LLama/LLamaInstructExecutor.cs (+4 -4)
14. LLama/LLamaQuantizer.cs (+1 -0)
15. LLama/LLamaWeights.cs (+4 -6)
16. LLama/Native/LLamaBatchSafeHandle.cs (+36 -11)
17. LLama/Native/LLamaGrammarElement.cs (+1 -1)
18. LLama/Native/LLamaPos.cs (+18 -7)
19. LLama/Native/LLamaSeqId.cs (+18 -7)
20. LLama/Native/LLamaTokenData.cs (+24 -24)
21. LLama/Native/SafeLLamaContextHandle.cs (+9 -3)
22. LLama/Native/SafeLlamaModelHandle.cs (+2 -2)

LLama/Abstractions/IModelParams.cs (+0 -5)

@@ -34,11 +34,6 @@ namespace LLama.Abstractions
/// </summary>
string ModelPath { get; set; }

-/// <summary>
-/// Number of threads (-1 = autodetect) (n_threads)
-/// </summary>
-uint? Threads { get; set; }
-
/// <summary>
/// how split tensors should be distributed across GPUs
/// </summary>


LLama/Common/FixedSizeQueue.cs (+5 -6)

@@ -12,7 +12,6 @@ namespace LLama.Common
public class FixedSizeQueue<T>
: IEnumerable<T>
{
-private readonly int _maxSize;
private readonly List<T> _storage;

internal IReadOnlyList<T> Items => _storage;
@@ -25,7 +24,7 @@ namespace LLama.Common
/// <summary>
/// Maximum number of items allowed in this queue
/// </summary>
-public int Capacity => _maxSize;
+public int Capacity { get; }

/// <summary>
/// Create a new queue
@@ -33,7 +32,7 @@ namespace LLama.Common
/// <param name="size">the maximum number of items to store in this queue</param>
public FixedSizeQueue(int size)
{
-_maxSize = size;
+Capacity = size;
_storage = new();
}

@@ -52,11 +51,11 @@ namespace LLama.Common
#endif

// Size of "data" is unknown, copy it all into a list
-_maxSize = size;
+Capacity = size;
_storage = new List<T>(data);

// Now check if that list is a valid size.
-if (_storage.Count > _maxSize)
+if (_storage.Count > Capacity)
throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values.");
}

@@ -81,7 +80,7 @@ namespace LLama.Common
public void Enqueue(T item)
{
_storage.Add(item);
-if(_storage.Count >= _maxSize)
+if(_storage.Count >= Capacity)
{
_storage.RemoveAt(0);
}
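
The net effect in this file: the private `_maxSize` field is gone and `Capacity` becomes a get-only auto-property assigned once per constructor. A minimal standalone sketch of the pattern (an illustrative `BoundedQueue`, not the full LLamaSharp class):

```csharp
using System.Collections;
using System.Collections.Generic;

// A get-only auto-property assigned in the constructor replaces a
// readonly backing field plus a property that merely forwarded to it.
public class BoundedQueue<T> : IEnumerable<T>
{
    private readonly List<T> _storage;

    /// <summary>Maximum number of items allowed in this queue</summary>
    public int Capacity { get; }

    public BoundedQueue(int capacity)
    {
        Capacity = capacity;
        _storage = new List<T>(capacity);
    }

    public void Enqueue(T item)
    {
        _storage.Add(item);
        if (_storage.Count > Capacity)
            _storage.RemoveAt(0); // evict the oldest item
    }

    public IEnumerator<T> GetEnumerator() => _storage.GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
```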


LLama/Common/ModelParams.cs (+3 -3)

@@ -40,11 +40,11 @@ namespace LLama.Common
/// <summary>
/// Use mlock to keep model in memory (use_mlock)
/// </summary>
-public bool UseMemoryLock { get; set; } = false;
+public bool UseMemoryLock { get; set; }
/// <summary>
/// Compute perplexity over the prompt (perplexity)
/// </summary>
-public bool Perplexity { get; set; } = false;
+public bool Perplexity { get; set; }
/// <summary>
/// Model path (model)
/// </summary>
@@ -79,7 +79,7 @@ namespace LLama.Common
/// Whether to use embedding mode. (embedding) Note that if this is set to true,
/// The LLamaModel won't produce text response anymore.
/// </summary>
-public bool EmbeddingMode { get; set; } = false;
+public bool EmbeddingMode { get; set; }

/// <summary>
/// how split tensors should be distributed across GPUs
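
The removed `= false` initializers were redundant: `default(bool)` is already `false`, so behaviour is identical. For illustration:

```csharp
public class Flags
{
    public bool WithInitializer { get; set; } = false; // explicit but redundant
    public bool WithoutInitializer { get; set; }       // same default value
}
```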


LLama/Exceptions/GrammarFormatExceptions.cs (+1 -1)

@@ -58,7 +58,7 @@ public class GrammarUnexpectedEndOfInput
: GrammarFormatException
{
internal GrammarUnexpectedEndOfInput()
: base($"Unexpected end of input")
: base("Unexpected end of input")
{
}
}


LLama/Exceptions/RuntimeError.cs (+14 -13)

@@ -1,19 +1,20 @@
using System;

-namespace LLama.Exceptions
-{
-    public class RuntimeError
-        : Exception
-    {
-        public RuntimeError()
-        {
-
-        }
-
-        public RuntimeError(string message)
-            : base(message)
-        {
-
-        }
-    }
-}
+namespace LLama.Exceptions;
+
+/// <summary>
+/// Base class for LLamaSharp runtime errors (i.e. errors produced by llama.cpp, converted into exceptions)
+/// </summary>
+public class RuntimeError
+    : Exception
+{
+    /// <summary>
+    /// Create a new RuntimeError
+    /// </summary>
+    /// <param name="message"></param>
+    public RuntimeError(string message)
+        : base(message)
+    {
+    }
+}
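
Besides adding XML docs and dropping the unused parameterless constructor, this file moves to a C# 10 file-scoped namespace. The two forms below are equivalent; a real file uses one or the other:

```csharp
// Block-scoped: the whole file is nested one level deep.
namespace LLama.Exceptions
{
    public class RuntimeError : System.Exception { }
}
```

```csharp
// File-scoped (C# 10+): same scoping, one less indentation level.
namespace LLama.Exceptions;

public class RuntimeError : System.Exception { }
```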

LLama/Extensions/EncodingExtensions.cs (+0 -1)

@@ -1,5 +1,4 @@
using System;
-using System.Collections.Generic;
using System.Text;

namespace LLama.Extensions;


LLama/Extensions/KeyValuePairExtensions.cs (+17 -20)

@@ -1,26 +1,23 @@
-using System.Collections.Generic;
-
-namespace LLama.Extensions
-{
-    /// <summary>
-    /// Extensions to the KeyValuePair struct
-    /// </summary>
-    internal static class KeyValuePairExtensions
-    {
-#if NETSTANDARD2_0
-        /// <summary>
-        /// Deconstruct a KeyValuePair into it's constituent parts.
-        /// </summary>
-        /// <param name="pair">The KeyValuePair to deconstruct</param>
-        /// <param name="first">First element, the Key</param>
-        /// <param name="second">Second element, the Value</param>
-        /// <typeparam name="TKey">Type of the Key</typeparam>
-        /// <typeparam name="TValue">Type of the Value</typeparam>
-        public static void Deconstruct<TKey, TValue>(this KeyValuePair<TKey, TValue> pair, out TKey first, out TValue second)
-        {
-            first = pair.Key;
-            second = pair.Value;
-        }
-#endif
-    }
-}
+namespace LLama.Extensions;
+
+/// <summary>
+/// Extensions to the KeyValuePair struct
+/// </summary>
+internal static class KeyValuePairExtensions
+{
+#if NETSTANDARD2_0
+    /// <summary>
+    /// Deconstruct a KeyValuePair into it's constituent parts.
+    /// </summary>
+    /// <param name="pair">The KeyValuePair to deconstruct</param>
+    /// <param name="first">First element, the Key</param>
+    /// <param name="second">Second element, the Value</param>
+    /// <typeparam name="TKey">Type of the Key</typeparam>
+    /// <typeparam name="TValue">Type of the Value</typeparam>
+    public static void Deconstruct<TKey, TValue>(this System.Collections.Generic.KeyValuePair<TKey, TValue> pair, out TKey first, out TValue second)
+    {
+        first = pair.Key;
+        second = pair.Value;
+    }
+#endif
+}
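
This extension exists because `KeyValuePair<,>.Deconstruct` ships with .NET Core 2.0+ but not netstandard2.0. With a Deconstruct in scope, dictionary iteration can use tuple syntax; a usage sketch (the library's extension is `internal`, so code outside LLamaSharp would declare its own):

```csharp
using System;
using System.Collections.Generic;

var ages = new Dictionary<string, int> { ["alice"] = 30, ["bob"] = 25 };

// On netstandard2.0 this foreach only compiles when a Deconstruct
// extension like the one above is visible:
foreach (var (name, age) in ages)
    Console.WriteLine($"{name} is {age}");
```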

LLama/Grammars/GBNFGrammarParser.cs (+46 -46)

@@ -17,7 +17,7 @@ namespace LLama.Grammars
{
// NOTE: assumes valid utf8 (but checks for overrun)
// copied from llama.cpp
-private uint DecodeUTF8(ref ReadOnlySpan<byte> src)
+private static uint DecodeUTF8(ref ReadOnlySpan<byte> src)
{
int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

@@ -40,46 +40,12 @@ namespace LLama.Grammars
return value;
}

-private uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len)
-{
-uint nextId = (uint)state.SymbolIds.Count;
-string key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray());
-
-if (state.SymbolIds.TryGetValue(key, out uint existingId))
-{
-return existingId;
-}
-else
-{
-state.SymbolIds[key] = nextId;
-return nextId;
-}
-}
-
-private uint GenerateSymbolId(ParseState state, string baseName)
-{
-uint nextId = (uint)state.SymbolIds.Count;
-string key = $"{baseName}_{nextId}";
-state.SymbolIds[key] = nextId;
-return nextId;
-}
-
-private void AddRule(ParseState state, uint ruleId, List<LLamaGrammarElement> rule)
-{
-while (state.Rules.Count <= ruleId)
-{
-state.Rules.Add(new List<LLamaGrammarElement>());
-}
-
-state.Rules[(int)ruleId] = rule;
-}
-
-private bool IsWordChar(byte c)
+private static bool IsWordChar(byte c)
{
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
}

-private uint ParseHex(ref ReadOnlySpan<byte> src, int size)
+private static uint ParseHex(ref ReadOnlySpan<byte> src, int size)
{
int pos = 0;
int end = size;
@@ -115,7 +81,7 @@ namespace LLama.Grammars
return value;
}

-private ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
+private static ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
{
int pos = 0;
while (pos < src.Length &&
@@ -137,7 +103,7 @@ namespace LLama.Grammars
return src.Slice(pos);
}

-private ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
+private static ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
{
int pos = 0;
while (pos < src.Length && IsWordChar(src[pos]))
@@ -151,7 +117,7 @@ namespace LLama.Grammars
return src.Slice(pos);
}

-private uint ParseChar(ref ReadOnlySpan<byte> src)
+private static uint ParseChar(ref ReadOnlySpan<byte> src)
{
if (src[0] == '\\')
{
@@ -235,7 +201,7 @@ namespace LLama.Grammars
else if (IsWordChar(pos[0])) // rule reference
{
var nameEnd = ParseName(pos);
-uint refRuleId = GetSymbolId(state, pos, nameEnd.Length);
+uint refRuleId = state.GetSymbolId(pos, nameEnd.Length);
pos = ParseSpace(nameEnd, isNested);
lastSymStart = outElements.Count;
outElements.Add(new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, refRuleId));
@@ -244,7 +210,7 @@ namespace LLama.Grammars
{
// parse nested alternates into synthesized rule
pos = ParseSpace(pos.Slice(1), true);
-uint subRuleId = GenerateSymbolId(state, ruleName);
+uint subRuleId = state.GenerateSymbolId(ruleName);
pos = ParseAlternates(state, pos, ruleName, subRuleId, true);
lastSymStart = outElements.Count;
// output reference to synthesized rule
@@ -263,7 +229,7 @@ namespace LLama.Grammars
// S* --> S' ::= S S' |
// S+ --> S' ::= S S' | S
// S? --> S' ::= S |
-uint subRuleId = GenerateSymbolId(state, ruleName);
+uint subRuleId = state.GenerateSymbolId(ruleName);

List<LLamaGrammarElement> subRule = new List<LLamaGrammarElement>();

@@ -287,7 +253,7 @@ namespace LLama.Grammars

subRule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));

-AddRule(state, subRuleId, subRule);
+state.AddRule(subRuleId, subRule);

// in original rule, replace previous symbol with reference to generated rule
outElements.RemoveRange(lastSymStart, outElements.Count - lastSymStart);
@@ -323,7 +289,7 @@ namespace LLama.Grammars
}

rule.Add(new LLamaGrammarElement(LLamaGrammarElementType.END, 0));
-AddRule(state, ruleId, rule);
+state.AddRule(ruleId, rule);

return pos;
}
@@ -333,7 +299,7 @@ namespace LLama.Grammars
ReadOnlySpan<byte> nameEnd = ParseName(src);
ReadOnlySpan<byte> pos = ParseSpace(nameEnd, false);
int nameLen = src.Length - nameEnd.Length;
-uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), 0);
+uint ruleId = state.GetSymbolId(src.Slice(0, nameLen), 0);
string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray());

if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '='))
@@ -393,6 +359,40 @@ namespace LLama.Grammars
{
public SortedDictionary<string, uint> SymbolIds { get; } = new();
public List<List<LLamaGrammarElement>> Rules { get; } = new();

+public uint GetSymbolId(ReadOnlySpan<byte> src, int len)
+{
+var nextId = (uint)SymbolIds.Count;
+var key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray());
+
+if (SymbolIds.TryGetValue(key, out uint existingId))
+{
+return existingId;
+}
+else
+{
+SymbolIds[key] = nextId;
+return nextId;
+}
+}
+
+public uint GenerateSymbolId(string baseName)
+{
+var nextId = (uint)SymbolIds.Count;
+var key = $"{baseName}_{nextId}";
+SymbolIds[key] = nextId;
+return nextId;
+}
+
+public void AddRule(uint ruleId, List<LLamaGrammarElement> rule)
+{
+while (Rules.Count <= ruleId)
+{
+Rules.Add(new List<LLamaGrammarElement>());
+}
+
+Rules[(int)ruleId] = rule;
+}
}
}
}
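
The parser change is a pure move: `GetSymbolId`, `GenerateSymbolId` and `AddRule` only ever touched `ParseState`, so they become members of it, and the remaining free helpers become `static`. A simplified sketch of the resulting shape (string keys instead of the real UTF-8 spans):

```csharp
using System.Collections.Generic;

// State-owned symbol table: callers write state.GetSymbolId(...)
// instead of threading the state through parser-level helpers.
internal sealed class ParseState
{
    public SortedDictionary<string, uint> SymbolIds { get; } = new();

    // Return the existing id for a name, or register the next free id.
    public uint GetSymbolId(string key)
    {
        if (SymbolIds.TryGetValue(key, out var existing))
            return existing;

        var next = (uint)SymbolIds.Count;
        SymbolIds[key] = next;
        return next;
    }

    // Mint a unique id for a synthesized rule, e.g. "root_3".
    public uint GenerateSymbolId(string baseName)
    {
        var next = (uint)SymbolIds.Count;
        SymbolIds[$"{baseName}_{next}"] = next;
        return next;
    }
}
```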

LLama/Grammars/Grammar.cs (+0 -1)

@@ -112,7 +112,6 @@ namespace LLama.Grammars
case LLamaGrammarElementType.CHAR_ALT:
PrintGrammarChar(output, elem.Value);
break;
-
}

if (elem.IsCharElement())


LLama/LLamaContext.cs (+37 -39)

@@ -23,23 +23,21 @@ namespace LLama
: IDisposable
{
private readonly ILogger? _logger;
-private readonly Encoding _encoding;
-private readonly SafeLLamaContextHandle _ctx;

/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
-public int VocabCount => _ctx.VocabCount;
+public int VocabCount => NativeHandle.VocabCount;

/// <summary>
/// Total number of tokens in the context
/// </summary>
-public int ContextSize => _ctx.ContextSize;
+public int ContextSize => NativeHandle.ContextSize;

/// <summary>
/// Dimension of embedding vectors
/// </summary>
-public int EmbeddingSize => _ctx.EmbeddingSize;
+public int EmbeddingSize => NativeHandle.EmbeddingSize;

/// <summary>
/// The context params set for this context
@@ -50,20 +48,20 @@ namespace LLama
/// The native handle, which is used to be passed to the native APIs
/// </summary>
/// <remarks>Be careful how you use this!</remarks>
-public SafeLLamaContextHandle NativeHandle => _ctx;
+public SafeLLamaContextHandle NativeHandle { get; }

/// <summary>
/// The encoding set for this model to deal with text input.
/// </summary>
-public Encoding Encoding => _encoding;
+public Encoding Encoding { get; }

internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null)
{
Params = @params;

_logger = logger;
-_encoding = @params.Encoding;
-_ctx = nativeContext;
+Encoding = @params.Encoding;
+NativeHandle = nativeContext;
}

/// <summary>
@@ -81,10 +79,10 @@ namespace LLama
Params = @params;

_logger = logger;
-_encoding = @params.Encoding;
+Encoding = @params.Encoding;

@params.ToLlamaContextParams(out var lparams);
-_ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
+NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
}

/// <summary>
@@ -96,7 +94,7 @@ namespace LLama
/// <returns></returns>
public llama_token[] Tokenize(string text, bool addBos = true, bool special = false)
{
-return _ctx.Tokenize(text, addBos, special, _encoding);
+return NativeHandle.Tokenize(text, addBos, special, Encoding);
}

/// <summary>
@@ -108,7 +106,7 @@ namespace LLama
{
var sb = new StringBuilder();
foreach (var token in tokens)
-_ctx.TokenToString(token, _encoding, sb);
+NativeHandle.TokenToString(token, Encoding, sb);

return sb.ToString();
}
@@ -124,7 +122,7 @@ namespace LLama
File.Delete(filename);

// Estimate size of state to write to disk, this is always equal to or greater than the actual size
-var estimatedStateSize = (long)NativeApi.llama_get_state_size(_ctx);
+var estimatedStateSize = (long)NativeApi.llama_get_state_size(NativeHandle);

// Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
long writtenBytes;
@@ -135,7 +133,7 @@ namespace LLama
{
byte* ptr = null;
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
-writtenBytes = (long)NativeApi.llama_copy_state_data(_ctx, ptr);
+writtenBytes = (long)NativeApi.llama_copy_state_data(NativeHandle, ptr);
view.SafeMemoryMappedViewHandle.ReleasePointer();
}
}
@@ -151,14 +149,14 @@ namespace LLama
/// <returns></returns>
public State GetState()
{
-var stateSize = _ctx.GetStateSize();
+var stateSize = NativeHandle.GetStateSize();

// Allocate a chunk of memory large enough to hold the entire state
var memory = Marshal.AllocHGlobal((nint)stateSize);
try
{
// Copy the state data into memory, discover the actual size required
-var actualSize = _ctx.GetState(memory, stateSize);
+var actualSize = NativeHandle.GetState(memory, stateSize);

// Shrink to size
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
@@ -193,7 +191,7 @@ namespace LLama
{
byte* ptr = null;
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
-NativeApi.llama_set_state_data(_ctx, ptr);
+NativeApi.llama_set_state_data(NativeHandle, ptr);
view.SafeMemoryMappedViewHandle.ReleasePointer();
}
}
@@ -208,7 +206,7 @@ namespace LLama
{
unsafe
{
-_ctx.SetState((byte*)state.DangerousGetHandle().ToPointer());
+NativeHandle.SetState((byte*)state.DangerousGetHandle().ToPointer());
}
}

@@ -235,13 +233,13 @@ namespace LLama

if (grammar != null)
{
-SamplingApi.llama_sample_grammar(_ctx, candidates, grammar);
+SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar);
}

if (temperature <= 0)
{
// Greedy sampling
-id = SamplingApi.llama_sample_token_greedy(_ctx, candidates);
+id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates);
}
else
{
@@ -250,23 +248,23 @@ namespace LLama
if (mirostat == MirostatType.Mirostat)
{
const int mirostat_m = 100;
-SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
+SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
+id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
}
else if (mirostat == MirostatType.Mirostat2)
{
-SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu);
+SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
+id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu);
}
else
{
// Temperature sampling
-SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1);
-SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1);
-SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1);
-SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1);
-SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
-id = SamplingApi.llama_sample_token(_ctx, candidates);
+SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1);
+SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1);
+SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1);
+SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1);
+SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
+id = SamplingApi.llama_sample_token(NativeHandle, candidates);
}
}
mirostat_mu = mu;
@@ -274,7 +272,7 @@ namespace LLama

if (grammar != null)
{
-NativeApi.llama_grammar_accept_token(_ctx, grammar, id);
+NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id);
}

return id;
@@ -295,7 +293,7 @@ namespace LLama
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
-var logits = _ctx.GetLogits();
+var logits = NativeHandle.GetLogits();

// Apply params.logit_bias map
if (logitBias is not null)
@@ -305,7 +303,7 @@ namespace LLama
}

// Save the newline logit value
-var nl_token = NativeApi.llama_token_nl(_ctx);
+var nl_token = NativeApi.llama_token_nl(NativeHandle);
var nl_logit = logits[nl_token];

// Convert logits into token candidates
@@ -316,8 +314,8 @@ namespace LLama
var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();

// Apply penalties to candidates
-SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, last_n_array, repeatPenalty);
-SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, last_n_array, alphaFrequency, alphaPresence);
+SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty);
+SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence);

// Restore newline token logit value if necessary
if (!penalizeNL)
@@ -408,9 +406,9 @@ namespace LLama
n_eval = (int)Params.BatchSize;
}

-if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount))
+if (!NativeHandle.Eval(tokens.Slice(i, n_eval), pastTokensCount))
{
_logger?.LogError($"[LLamaContext] Failed to eval.");
_logger?.LogError("[LLamaContext] Failed to eval.");
throw new RuntimeError("Failed to eval.");
}

@@ -443,7 +441,7 @@ namespace LLama
/// <inheritdoc />
public void Dispose()
{
-_ctx.Dispose();
+NativeHandle.Dispose();
}

/// <summary>
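
Most of this file is a mechanical rename: the `_ctx` and `_encoding` fields become the `NativeHandle` and `Encoding` get-only auto-properties. One pattern worth noting from `GetState` is allocate-then-shrink; a standalone sketch with an illustrative delegate standing in for the native copy call:

```csharp
using System;
using System.Runtime.InteropServices;

static class StateCopy
{
    // Reserve an upper-bound buffer, let the copy report how many bytes
    // it actually wrote, then shrink the allocation to fit.
    public static IntPtr CopyWithShrink(ulong estimatedSize, Func<IntPtr, ulong, ulong> copyInto)
    {
        var memory = Marshal.AllocHGlobal((nint)estimatedSize);
        try
        {
            var actualSize = copyInto(memory, estimatedSize);

            // ReAllocHGlobal may move the block; keep the returned pointer.
            memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
            return memory;
        }
        catch
        {
            Marshal.FreeHGlobal(memory);
            throw;
        }
    }
}
```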


LLama/LLamaEmbedder.cs (+16 -0)

@@ -18,11 +18,22 @@ namespace LLama
/// </summary>
public int EmbeddingSize => _ctx.EmbeddingSize;

+/// <summary>
+/// Create a new embedder (loading temporary weights)
+/// </summary>
+/// <param name="allParams"></param>
+[Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
+public LLamaEmbedder(ILLamaParams allParams)
+: this(allParams, allParams)
+{
+}

/// <summary>
/// Create a new embedder (loading temporary weights)
/// </summary>
/// <param name="modelParams"></param>
/// <param name="contextParams"></param>
[Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams)
{
using var weights = LLamaWeights.LoadFromFile(modelParams);
@@ -31,6 +42,11 @@ namespace LLama
_ctx = weights.CreateContext(contextParams);
}

+/// <summary>
+/// Create a new embedder, using the given LLamaWeights
+/// </summary>
+/// <param name="weights"></param>
+/// <param name="params"></param>
public LLamaEmbedder(LLamaWeights weights, IContextParams @params)
{
@params.EmbeddingMode = true;
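
The two obsoleted constructors each load (and immediately discard) a full copy of the weights; the remaining constructor shares preloaded weights and sets `EmbeddingMode` itself. A usage sketch ("model.gguf" is a placeholder path):

```csharp
using LLama;
using LLama.Common;

var @params = new ModelParams("model.gguf");

// Load the weights once, then reuse them across components:
using var weights = LLamaWeights.LoadFromFile(@params);
using var embedder = new LLamaEmbedder(weights, @params);
```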


LLama/LLamaExecutorBase.cs (+1 -1)

@@ -114,7 +114,7 @@ namespace LLama
}
else
{
_logger?.LogWarning($"[LLamaExecutor] Session file does not exist, will create");
_logger?.LogWarning("[LLamaExecutor] Session file does not exist, will create");
}

_n_matching_session_tokens = 0;


LLama/LLamaInstructExecutor.cs (+4 -4)

@@ -18,10 +18,10 @@ namespace LLama
/// </summary>
public class InstructExecutor : StatefulExecutorBase
{
-bool _is_prompt_run = true;
-string _instructionPrefix;
-llama_token[] _inp_pfx;
-llama_token[] _inp_sfx;
+private bool _is_prompt_run = true;
+private readonly string _instructionPrefix;
+private llama_token[] _inp_pfx;
+private llama_token[] _inp_sfx;

/// <summary>
///
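
A style-only change that documents intent: `readonly` marks the instruction prefix as fixed after construction, while the remaining fields are mutable run state. For example:

```csharp
public class Example
{
    // Set once in the constructor, never reassigned afterwards.
    private readonly string _instructionPrefix;

    // Mutable state, flipped after the first prompt has been processed.
    private bool _isPromptRun = true;

    public Example(string instructionPrefix)
    {
        _instructionPrefix = instructionPrefix;
    }

    public string Prefix => _instructionPrefix;
    public void CompleteFirstRun() => _isPromptRun = false;
}
```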


LLama/LLamaQuantizer.cs (+1 -0)

@@ -80,6 +80,7 @@ namespace LLama
return true;

case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
+case LLamaFtype.LLAMA_FTYPE_GUESSED:
default:
return false;
}
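
`LLAMA_FTYPE_GUESSED` is a pseudo-ftype rather than a real quantization target, so listing it next to the other unsupported case makes the rejection explicit instead of incidental. An abridged sketch of the validation switch (only a couple of the allowed cases shown; the enum's `LLama.Native` namespace is assumed):

```csharp
using LLama.Native;

static class QuantizeValidation
{
    public static bool ValidateFtype(LLamaFtype ftype)
    {
        switch (ftype)
        {
            // Real targets are allow-listed explicitly:
            case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_0:
            case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1:
                return true;

            // Pseudo/unsupported ftypes fall through and are rejected:
            case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
            case LLamaFtype.LLAMA_FTYPE_GUESSED:
            default:
                return false;
        }
    }
}
```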


LLama/LLamaWeights.cs (+4 -6)

@@ -11,13 +11,11 @@ namespace LLama
public sealed class LLamaWeights
: IDisposable
{
-private readonly SafeLlamaModelHandle _weights;
-
/// <summary>
/// The native handle, which is used in the native APIs
/// </summary>
/// <remarks>Be careful how you use this!</remarks>
-public SafeLlamaModelHandle NativeHandle => _weights;
+public SafeLlamaModelHandle NativeHandle { get; }

/// <summary>
/// Total number of tokens in vocabulary of this model
@@ -46,7 +44,7 @@ namespace LLama

internal LLamaWeights(SafeLlamaModelHandle weights)
{
-_weights = weights;
+NativeHandle = weights;
}

/// <summary>
@@ -66,7 +64,7 @@ namespace LLama
if (adapter.Scale <= 0)
continue;

-weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase, @params.Threads);
+weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
}

return new LLamaWeights(weights);
@@ -75,7 +73,7 @@ namespace LLama
/// <inheritdoc />
public void Dispose()
{
-_weights.Dispose();
+NativeHandle.Dispose();
}

/// <summary>


LLama/Native/LLamaBatchSafeHandle.cs (+36 -11)

@@ -4,11 +4,18 @@ namespace LLama.Native;

using llama_token = Int32;

+/// <summary>
+/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences.
+/// </summary>
public sealed class LLamaBatchSafeHandle
: SafeLLamaHandleBase
{
private readonly int _embd;
-public LLamaNativeBatch Batch { get; private set; }

+/// <summary>
+/// Get the native llama_batch struct
+/// </summary>
+public LLamaNativeBatch NativeBatch { get; private set; }

/// <summary>
/// the token ids of the input (used when embd is NULL)
@@ -22,7 +29,7 @@ public sealed class LLamaBatchSafeHandle
if (_embd != 0)
return new Span<int>(null, 0);
else
-return new Span<int>(Batch.token, Batch.n_tokens);
+return new Span<int>(NativeBatch.token, NativeBatch.n_tokens);
}
}
}
@@ -37,10 +44,10 @@ public sealed class LLamaBatchSafeHandle
unsafe
{
// If embd != 0, llama_batch.embd will be allocated with size of n_tokens *embd * sizeof(float)
-/// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
+// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token

if (_embd != 0)
-return new Span<llama_token>(Batch.embd, Batch.n_tokens * _embd);
+return new Span<llama_token>(NativeBatch.embd, NativeBatch.n_tokens * _embd);
else
return new Span<llama_token>(null, 0);
}
@@ -56,7 +63,7 @@ public sealed class LLamaBatchSafeHandle
{
unsafe
{
-return new Span<LLamaPos>(Batch.pos, Batch.n_tokens);
+return new Span<LLamaPos>(NativeBatch.pos, NativeBatch.n_tokens);
}
}
}
@@ -70,7 +77,7 @@ public sealed class LLamaBatchSafeHandle
{
unsafe
{
-return new Span<LLamaSeqId>(Batch.seq_id, Batch.n_tokens);
+return new Span<LLamaSeqId>(NativeBatch.seq_id, NativeBatch.n_tokens);
}
}
}
@@ -84,22 +91,40 @@ public sealed class LLamaBatchSafeHandle
{
unsafe
{
-return new Span<byte>(Batch.logits, Batch.n_tokens);
+return new Span<byte>(NativeBatch.logits, NativeBatch.n_tokens);
}
}
}

-public LLamaBatchSafeHandle(int n_tokens, int embd)
+/// <summary>
+/// Create a safe handle owning a `LLamaNativeBatch`
+/// </summary>
+/// <param name="batch"></param>
+/// <param name="embd"></param>
+public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
: base((nint)1)
{
_embd = embd;
-Batch = NativeApi.llama_batch_init(n_tokens, embd);
+NativeBatch = batch;
}

+/// <summary>
+/// Call `llama_batch_init` and create a new batch
+/// </summary>
+/// <param name="n_tokens"></param>
+/// <param name="embd"></param>
+/// <returns></returns>
+public static LLamaBatchSafeHandle Create(int n_tokens, int embd)
+{
+var batch = NativeApi.llama_batch_init(n_tokens, embd);
+return new LLamaBatchSafeHandle(batch, embd);
+}

/// <inheritdoc />
protected override bool ReleaseHandle()
{
-NativeApi.llama_batch_free(Batch);
-Batch = default;
+NativeApi.llama_batch_free(NativeBatch);
+NativeBatch = default;
SetHandle(IntPtr.Zero);
return true;
}
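
The constructor used to both call `llama_batch_init` and wrap the result; allocation now lives in the static `Create` factory, and the constructor only takes ownership of an existing `LLamaNativeBatch`. Usage sketch:

```csharp
using LLama.Native;

class BatchExample
{
    static void Run()
    {
        // Create calls llama_batch_init and hands ownership to the safe
        // handle; llama_batch_free runs when the handle is disposed.
        using var batch = LLamaBatchSafeHandle.Create(n_tokens: 512, embd: 0);

        // batch.NativeBatch exposes the raw llama_batch, e.g. for
        // SafeLLamaContextHandle.Decode(batch).
    }
}
```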

LLama/Native/LLamaGrammarElement.cs (+1 -1)

@@ -45,7 +45,7 @@ namespace LLama.Native
/// CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
/// </summary>
CHAR_ALT = 6,
-};
+}

/// <summary>
/// An element of a grammar


LLama/Native/LLamaPos.cs (+18 -7)

@@ -1,15 +1,26 @@
namespace LLama.Native;

-public record struct LLamaPos
+/// <summary>
+/// Indicates position in a sequence
+/// </summary>
+public readonly record struct LLamaPos(int Value)
{
-public int Value;
-
-public LLamaPos(int value)
-{
-Value = value;
-}
+/// <summary>
+/// The raw value
+/// </summary>
+public readonly int Value = Value;

+/// <summary>
+/// Convert a LLamaPos into an integer (extract the raw value)
+/// </summary>
+/// <param name="pos"></param>
+/// <returns></returns>
public static explicit operator int(LLamaPos pos) => pos.Value;

+/// <summary>
+/// Convert an integer into a LLamaPos
+/// </summary>
+/// <param name="value"></param>
+/// <returns></returns>
public static implicit operator LLamaPos(int value) => new(value);
}

LLama/Native/LLamaSeqId.cs (+18 -7)

@@ -1,15 +1,26 @@
namespace LLama.Native;

-public record struct LLamaSeqId
+/// <summary>
+/// ID for a sequence in a batch
+/// </summary>
+/// <param name="Value"></param>
+public record struct LLamaSeqId(int Value)
{
-public int Value;
-
-public LLamaSeqId(int value)
-{
-Value = value;
-}
+/// <summary>
+/// The raw value
+/// </summary>
+public int Value = Value;

+/// <summary>
+/// Convert a LLamaSeqId into an integer (extract the raw value)
+/// </summary>
+/// <param name="pos"></param>
public static explicit operator int(LLamaSeqId pos) => pos.Value;

+/// <summary>
+/// Convert an integer into a LLamaSeqId
+/// </summary>
+/// <param name="value"></param>
+/// <returns></returns>
public static explicit operator LLamaSeqId(int value) => new(value);
}
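
Both wrappers follow the same pattern: a `record struct` around a raw `int`, so positions and sequence ids get distinct types instead of bare integers. Note the asymmetry: `int` converts to `LLamaPos` implicitly, while `LLamaSeqId` requires explicit casts in both directions. A usage sketch:

```csharp
using System;
using LLama.Native;

LLamaPos pos = 42;          // implicit: int -> LLamaPos
var seq = (LLamaSeqId)3;    // explicit in both directions

int rawPos = (int)pos;
int rawSeq = (int)seq;
Console.WriteLine($"pos={rawPos} seq={rawSeq}");

// record structs compare by value:
Console.WriteLine(pos == new LLamaPos(42)); // True
```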

LLama/Native/LLamaTokenData.cs (+24 -24)

@@ -1,28 +1,28 @@
using System.Runtime.InteropServices;

-namespace LLama.Native
-{
-    [StructLayout(LayoutKind.Sequential)]
-    public struct LLamaTokenData
-    {
-        /// <summary>
-        /// token id
-        /// </summary>
-        public int id;
-        /// <summary>
-        /// log-odds of the token
-        /// </summary>
-        public float logit;
-        /// <summary>
-        /// probability of the token
-        /// </summary>
-        public float p;
-
-        public LLamaTokenData(int id, float logit, float p)
-        {
-            this.id = id;
-            this.logit = logit;
-            this.p = p;
-        }
-    }
-}
+namespace LLama.Native;
+
+/// <summary>
+/// A single token along with probability of this token being selected
+/// </summary>
+/// <param name="id"></param>
+/// <param name="logit"></param>
+/// <param name="p"></param>
+[StructLayout(LayoutKind.Sequential)]
+public record struct LLamaTokenData(int id, float logit, float p)
+{
+    /// <summary>
+    /// token id
+    /// </summary>
+    public int id = id;
+
+    /// <summary>
+    /// log-odds of the token
+    /// </summary>
+    public float logit = logit;
+
+    /// <summary>
+    /// probability of the token
+    /// </summary>
+    public float p = p;
+}
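
The struct keeps `[StructLayout(LayoutKind.Sequential)]` and its three public fields, so the native layout is untouched; the primary-constructor parameters now initialize the same-named fields (`public int id = id;`) and `record struct` adds value equality. For example:

```csharp
using System;
using LLama.Native;

var a = new LLamaTokenData(id: 7, logit: -1.5f, p: 0.12f);
var b = new LLamaTokenData(id: 7, logit: -1.5f, p: 0.12f);

// record structs compare all instance fields:
Console.WriteLine(a == b); // True
```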

LLama/Native/SafeLLamaContextHandle.cs (+9 -3)

@@ -1,7 +1,5 @@
using System;
-using System.Buffers;
-using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;

@@ -212,9 +210,17 @@ namespace LLama.Native
}
}

+/// <summary>
+/// </summary>
+/// <param name="batch"></param>
+/// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
+/// - 0: success<br />
+/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
+/// - &lt; 0: error<br />
+/// </returns>
public int Decode(LLamaBatchSafeHandle batch)
{
-return NativeApi.llama_decode(this, batch.Batch);
+return NativeApi.llama_decode(this, batch.NativeBatch);
}

#region state
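
The new doc comment spells out the tri-state result of `llama_decode`, which callers are expected to branch on. A sketch of a checked wrapper (the extension method itself is illustrative, not part of the library):

```csharp
using System;
using LLama.Exceptions;
using LLama.Native;

static class DecodeExtensions
{
    // 0 = success; 1 = no KV slot (retry with a smaller batch or a
    // larger context); < 0 = hard failure.
    public static void DecodeChecked(this SafeLLamaContextHandle ctx, LLamaBatchSafeHandle batch)
    {
        var result = ctx.Decode(batch);
        if (result == 0)
            return;

        if (result == 1)
            throw new InvalidOperationException(
                "No KV slot for this batch; reduce the batch size or increase the context.");

        throw new RuntimeError($"llama_decode failed with {result}");
    }
}
```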


LLama/Native/SafeLlamaModelHandle.cs (+2 -2)

@@ -84,14 +84,14 @@ namespace LLama.Native
/// adapter. Can be NULL to use the current loaded model.</param>
/// <param name="threads"></param>
/// <exception cref="RuntimeError"></exception>
-public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, uint? threads = null)
+public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, int? threads = null)
{
var err = NativeApi.llama_model_apply_lora_from_file(
this,
lora,
scale,
string.IsNullOrEmpty(modelBase) ? null : modelBase,
-(int?)threads ?? -1
+threads ?? Math.Max(1, Environment.ProcessorCount / 2)
);

if (err != 0)
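
The call previously passed `-1` (llama.cpp autodetect) when no thread count was given; the default is now computed on the managed side. The heuristic:

```csharp
using System;

// Half the logical processors, but never fewer than one. On machines
// with SMT this roughly approximates the physical core count.
int threads = Math.Max(1, Environment.ProcessorCount / 2);
Console.WriteLine($"Default LoRA threads: {threads}");
```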

