diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs index 2c06dd47..11676617 100644 --- a/LLama.Unittest/BasicTest.cs +++ b/LLama.Unittest/BasicTest.cs @@ -38,6 +38,8 @@ namespace LLama.Unittest [Fact] public void AdvancedModelProperties() { + // These are the keys in the llama 7B test model. This will need changing if + // tests are switched to use a new model! var expected = new Dictionary { { "general.name", "LLaMA v2" }, @@ -60,31 +62,16 @@ namespace LLama.Unittest { "tokenizer.ggml.unknown_token_id", "0" }, }; - var metaCount = NativeApi.llama_model_meta_count(_model.NativeHandle); - Assert.Equal(expected.Count, metaCount); + // Print all keys + foreach (var (key, value) in _model.Metadata) + _testOutputHelper.WriteLine($"{key} = {value}"); - Span buffer = stackalloc byte[128]; - for (var i = 0; i < expected.Count; i++) - { - unsafe - { - fixed (byte* ptr = buffer) - { - var length = NativeApi.llama_model_meta_key_by_index(_model.NativeHandle, i, ptr, 128); - Assert.True(length > 0); - var key = Encoding.UTF8.GetString(buffer[..length]); - - length = NativeApi.llama_model_meta_val_str_by_index(_model.NativeHandle, i, ptr, 128); - Assert.True(length > 0); - var val = Encoding.UTF8.GetString(buffer[..length]); - - _testOutputHelper.WriteLine($"{key} == {val}"); + // Check the count is equal + Assert.Equal(expected.Count, _model.Metadata.Count); - Assert.True(expected.ContainsKey(key)); - Assert.Equal(expected[key], val); - } - } - } + // Check every key + foreach (var (key, value) in _model.Metadata) + Assert.Equal(expected[key], value); } } } \ No newline at end of file diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 7ae104a5..54760cb8 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -1,4 +1,7 @@ using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; using LLama.Abstractions; using LLama.Extensions; using LLama.Native; @@ -58,9 +61,15 @@ namespace LLama /// public int EmbeddingSize => NativeHandle.EmbeddingSize; + /// + /// All metadata keys in this model + /// + public IReadOnlyDictionary Metadata { get; set; } + internal LLamaWeights(SafeLlamaModelHandle weights) { NativeHandle = weights; + Metadata = weights.ReadMetadata(); } /// diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index b93c2b89..a3e66bbb 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -1,5 +1,7 @@ using System; +using System.Buffers; using System.Collections.Generic; +using System.Security.Cryptography; using System.Text; using LLama.Exceptions; @@ -36,6 +38,12 @@ namespace LLama.Native /// public ulong ParameterCount { get; } + /// + /// Get the number of metadata key/value pairs + /// + /// + public int MetadataCount { get; } + internal SafeLlamaModelHandle(IntPtr handle) : base(handle) { @@ -44,6 +52,7 @@ namespace LLama.Native EmbeddingSize = NativeApi.llama_n_embd(this); SizeInBytes = NativeApi.llama_model_size(this); ParameterCount = NativeApi.llama_model_n_params(this); + MetadataCount = NativeApi.llama_model_meta_count(this); } /// @@ -199,5 +208,75 @@ namespace LLama.Native return SafeLLamaContextHandle.Create(this, @params); } #endregion + + #region metadata + /// + /// Get the metadata key for the given index + /// + /// The index to get + /// A temporary buffer to store key characters in. Must be large enough to contain the key. + /// The key, null if there is no such key or if the buffer was too small + public Memory? MetadataKeyByIndex(int index, Memory buffer) + { + unsafe + { + using var pin = buffer.Pin(); + var keyLength = NativeApi.llama_model_meta_key_by_index(this, index, (byte*)pin.Pointer, 1024); + if (keyLength < 0) + return null; + return buffer.Slice(keyLength); + } + } + + /// + /// Get the metadata value for the given index + /// + /// The index to get + /// A temporary buffer to store value characters in. Must be large enough to contain the value. + /// The value, null if there is no such value or if the buffer was too small + public Memory? MetadataValueByIndex(int index, Memory buffer) + { + unsafe + { + using var pin = buffer.Pin(); + var keyLength = NativeApi.llama_model_meta_key_by_index(this, index, (byte*)pin.Pointer, 1024); + if (keyLength < 0) + return null; + return buffer.Slice(keyLength); + } + } + + internal IReadOnlyDictionary ReadMetadata() + { + var result = new Dictionary(); + + var dest = ArrayPool.Shared.Rent(1024); + try + { + for (var i = 0; i < MetadataCount; i++) + { + Array.Clear(dest, 0, dest.Length); + + var keyBytes = MetadataKeyByIndex(i, dest.AsMemory()); + if (keyBytes == null) + continue; + var key = Encoding.UTF8.GetString(keyBytes.Value.Span); + + var valBytes = MetadataValueByIndex(i, dest.AsMemory()); + if (valBytes == null) + continue; + var val = Encoding.UTF8.GetString(valBytes.Value.Span); + + result[key] = val; + } + } + finally + { + ArrayPool.Shared.Return(dest); + } + + return result; + } + #endregion } }