diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
index 56b7baee..fada91a1 100644
--- a/LLama/Abstractions/IModelParams.cs
+++ b/LLama/Abstractions/IModelParams.cs
@@ -5,7 +5,6 @@ using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
-using LLama.Common;
using LLama.Native;
namespace LLama.Abstractions
@@ -110,6 +109,7 @@ namespace LLama.Abstractions
}
}
+
///
/// A fixed size array to set the tensor splits across multiple GPUs
///
@@ -204,6 +204,7 @@ namespace LLama.Abstractions
}
}
+
///
/// An override for a single key/value pair in model metadata
///
@@ -243,57 +244,92 @@ namespace LLama.Abstractions
return new BoolOverride(key, value);
}
- internal abstract void Write(ref LLamaModelMetadataOverride dest);
-
///
/// Get the key being overriden by this override
///
public abstract string Key { get; init; }
+ internal abstract LLamaModelKvOverrideType Type { get; }
+
+ internal abstract void WriteValue(ref LLamaModelMetadataOverride dest);
+
+ internal abstract void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options);
+
private record IntOverride(string Key, int Value) : MetadataOverride
{
- internal override void Write(ref LLamaModelMetadataOverride dest)
+ internal override LLamaModelKvOverrideType Type => LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
+
+ internal override void WriteValue(ref LLamaModelMetadataOverride dest)
{
- dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
dest.IntValue = Value;
}
+
+ internal override void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options)
+ {
+ writer.WriteNumberValue(Value);
+ }
}
private record FloatOverride(string Key, float Value) : MetadataOverride
{
- internal override void Write(ref LLamaModelMetadataOverride dest)
+ internal override LLamaModelKvOverrideType Type => LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
+
+ internal override void WriteValue(ref LLamaModelMetadataOverride dest)
{
- dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
dest.FloatValue = Value;
}
+
+ internal override void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options)
+ {
+ writer.WriteNumberValue(Value);
+ }
}
private record BoolOverride(string Key, bool Value) : MetadataOverride
{
- internal override void Write(ref LLamaModelMetadataOverride dest)
+ internal override LLamaModelKvOverrideType Type => LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
+
+ internal override void WriteValue(ref LLamaModelMetadataOverride dest)
{
- dest.Tag = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
dest.BoolValue = Value ? -1 : 0;
}
+
+ internal override void WriteValue(Utf8JsonWriter writer, JsonSerializerOptions options)
+ {
+ writer.WriteBooleanValue(Value);
+ }
}
}
+ ///
+ /// A JSON converter for
+ ///
public class MetadataOverrideConverter
: JsonConverter
{
+ ///
+ public override bool CanConvert(Type typeToConvert)
+ {
+ return typeof(MetadataOverride).IsAssignableFrom(typeToConvert);
+ }
+
///
public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
- throw new NotImplementedException();
- //var arr = JsonSerializer.Deserialize(ref reader, options) ?? Array.Empty();
- //return new TensorSplitsCollection(arr);
+ throw new NotImplementedException("for some reason this is never called!");
}
///
public override void Write(Utf8JsonWriter writer, MetadataOverride value, JsonSerializerOptions options)
{
- throw new NotImplementedException();
- //JsonSerializer.Serialize(writer, value.Splits, options);
+ writer.WriteStartObject();
+ {
+ writer.WriteString("Key", value.Key);
+ writer.WriteNumber("Type", (int)value.Type);
+ writer.WritePropertyName("Value");
+ value.WriteValue(writer, options);
+ }
+ writer.WriteEndObject();
}
}
}
\ No newline at end of file
diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
index f1a9dea9..08805d32 100644
--- a/LLama/Extensions/IModelParamsExtensions.cs
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -57,10 +57,12 @@ public static class IModelParamsExtensions
for (var i = 0; i < @params.MetadataOverrides.Count; i++)
{
var item = @params.MetadataOverrides[i];
- var native = new LLamaModelMetadataOverride();
+ var native = new LLamaModelMetadataOverride
+ {
+ Tag = item.Type
+ };
- // Init value and tag
- item.Write(ref native);
+ item.WriteValue(ref native);
// Convert key to bytes
unsafe