| @@ -11,13 +11,13 @@ | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" /> | |||
| <PackageReference Include="xunit" Version="2.4.2" /> | |||
| <PackageReference Include="xunit.runner.visualstudio" Version="2.4.5"> | |||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" /> | |||
| <PackageReference Include="xunit" Version="2.5.0" /> | |||
| <PackageReference Include="xunit.runner.visualstudio" Version="2.5.0"> | |||
| <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | |||
| <PrivateAssets>all</PrivateAssets> | |||
| </PackageReference> | |||
| <PackageReference Include="coverlet.collector" Version="3.1.2"> | |||
| <PackageReference Include="coverlet.collector" Version="6.0.0"> | |||
| <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | |||
| <PrivateAssets>all</PrivateAssets> | |||
| </PackageReference> | |||
| @@ -1,4 +1,6 @@ | |||
| using LLama.Common; | |||
| using System.Text; | |||
| using LLama.Common; | |||
| using Newtonsoft.Json; | |||
| namespace LLama.Unittest | |||
| { | |||
| @@ -17,12 +19,33 @@ namespace LLama.Unittest | |||
| GpuLayerCount = 111 | |||
| }; | |||
| var json = System.Text.Json.JsonSerializer.Serialize(expected); | |||
| var actual = System.Text.Json.JsonSerializer.Deserialize<ModelParams>(json); | |||
| var options = new System.Text.Json.JsonSerializerOptions(); | |||
| options.Converters.Add(new SystemTextJsonEncodingConverter()); | |||
| var json = System.Text.Json.JsonSerializer.Serialize(expected, options); | |||
| var actual = System.Text.Json.JsonSerializer.Deserialize<ModelParams>(json, options); | |||
| Assert.Equal(expected, actual); | |||
| } | |||
| private class SystemTextJsonEncodingConverter | |||
| : System.Text.Json.Serialization.JsonConverter<Encoding> | |||
| { | |||
| public override Encoding? Read(ref System.Text.Json.Utf8JsonReader reader, Type typeToConvert, System.Text.Json.JsonSerializerOptions options) | |||
| { | |||
| var name = reader.GetString(); | |||
| if (name == null) | |||
| return null; | |||
| return Encoding.GetEncoding(name); | |||
| } | |||
| public override void Write(System.Text.Json.Utf8JsonWriter writer, Encoding value, System.Text.Json.JsonSerializerOptions options) | |||
| { | |||
| writer.WriteStringValue(value.WebName); | |||
| } | |||
| } | |||
| [Fact] | |||
| public void SerializeRoundTripNewtonsoft() | |||
| { | |||
| @@ -36,10 +59,30 @@ namespace LLama.Unittest | |||
| GpuLayerCount = 111 | |||
| }; | |||
| var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected); | |||
| var actual = Newtonsoft.Json.JsonConvert.DeserializeObject<ModelParams>(json); | |||
| var settings = new Newtonsoft.Json.JsonSerializerSettings(); | |||
| settings.Converters.Add(new NewtsonsoftEncodingConverter()); | |||
| var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected, settings); | |||
| var actual = Newtonsoft.Json.JsonConvert.DeserializeObject<ModelParams>(json, settings); | |||
| Assert.Equal(expected, actual); | |||
| } | |||
| private class NewtsonsoftEncodingConverter | |||
| : Newtonsoft.Json.JsonConverter<Encoding> | |||
| { | |||
| public override void WriteJson(JsonWriter writer, Encoding? value, JsonSerializer serializer) | |||
| { | |||
| writer.WriteValue((string?)value?.WebName); | |||
| } | |||
| public override Encoding? ReadJson(JsonReader reader, Type objectType, Encoding? existingValue, bool hasExistingValue, JsonSerializer serializer) | |||
| { | |||
| var name = (string?)reader.Value; | |||
| if (name == null) | |||
| return null; | |||
| return Encoding.GetEncoding(name); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using LLama.Abstractions; | |||
| using System.Text; | |||
| using LLama.Abstractions; | |||
| namespace LLama.Web.Common | |||
| { | |||
| @@ -115,6 +116,6 @@ namespace LLama.Web.Common | |||
| /// <summary> | |||
| /// The encoding to use for models | |||
| /// </summary> | |||
| public string Encoding { get; set; } = "UTF-8"; | |||
| public Encoding Encoding { get; set; } = Encoding.UTF8; | |||
| } | |||
| } | |||
| @@ -1,4 +1,6 @@ | |||
| namespace LLama.Abstractions | |||
| using System.Text; | |||
| namespace LLama.Abstractions | |||
| { | |||
| public interface IModelParams | |||
| { | |||
| @@ -121,6 +123,6 @@ | |||
| /// <summary> | |||
| /// The encoding to use for models | |||
| /// </summary> | |||
| string Encoding { get; set; } | |||
| Encoding Encoding { get; set; } | |||
| } | |||
| } | |||
| @@ -1,5 +1,6 @@ | |||
| using LLama.Abstractions; | |||
| using System; | |||
| using System.Text; | |||
| namespace LLama.Common | |||
| { | |||
| @@ -114,7 +115,7 @@ namespace LLama.Common | |||
| /// <summary> | |||
| /// The encoding to use to convert text for the model | |||
| /// </summary> | |||
| public string Encoding { get; set; } = "UTF-8"; | |||
| public Encoding Encoding { get; set; } = Encoding.UTF8; | |||
| /// <summary> | |||
| /// | |||
| @@ -183,7 +184,7 @@ namespace LLama.Common | |||
| RopeFrequencyBase = ropeFrequencyBase; | |||
| RopeFrequencyScale = ropeFrequencyScale; | |||
| MulMatQ = mulMatQ; | |||
| Encoding = encoding; | |||
| Encoding = Encoding.GetEncoding(encoding); | |||
| } | |||
| } | |||
| } | |||
| @@ -68,7 +68,7 @@ namespace LLama | |||
| Params = @params; | |||
| _logger = logger; | |||
| _encoding = Encoding.GetEncoding(@params.Encoding); | |||
| _encoding = @params.Encoding; | |||
| _logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info); | |||
| _ctx = Utils.InitLLamaContextFromModelParams(Params); | |||
| @@ -79,7 +79,7 @@ namespace LLama | |||
| Params = @params; | |||
| _logger = logger; | |||
| _encoding = Encoding.GetEncoding(@params.Encoding); | |||
| _encoding = @params.Encoding; | |||
| _ctx = nativeContext; | |||
| } | |||
| @@ -98,7 +98,7 @@ namespace LLama | |||
| Params = @params; | |||
| _logger = logger; | |||
| _encoding = Encoding.GetEncoding(@params.Encoding); | |||
| _encoding = @params.Encoding; | |||
| using var pin = @params.ToLlamaContextParams(out var lparams); | |||
| _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); | |||
| @@ -47,7 +47,7 @@ namespace LLama | |||
| [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")] | |||
| public StatelessExecutor(LLamaContext context) | |||
| { | |||
| _weights = new LLamaWeights(context.NativeHandle.ModelHandle, Encoding.GetEncoding(context.Params.Encoding)); | |||
| _weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding); | |||
| _params = context.Params; | |||
| Context = _weights.CreateContext(_params); | |||
| @@ -59,7 +59,7 @@ namespace LLama | |||
| if (!string.IsNullOrEmpty(@params.LoraAdapter)) | |||
| weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); | |||
| return new LLamaWeights(weights, Encoding.GetEncoding(@params.Encoding)); | |||
| return new LLamaWeights(weights, @params.Encoding); | |||
| } | |||
| /// <inheritdoc /> | |||