Browse Source

Switched to properly typed `Encoding` property

tags/v0.5.1
Martin Evans 2 years ago
parent
commit
93f24f8a51
8 changed files with 67 additions and 20 deletions
  1. +4
    -4
      LLama.Unittest/LLama.Unittest.csproj
  2. +48
    -5
      LLama.Unittest/ModelsParamsTests.cs
  3. +3
    -2
      LLama.Web/Common/ModelOptions.cs
  4. +4
    -2
      LLama/Abstractions/IModelParams.cs
  5. +3
    -2
      LLama/Common/ModelParams.cs
  6. +3
    -3
      LLama/LLamaContext.cs
  7. +1
    -1
      LLama/LLamaStatelessExecutor.cs
  8. +1
    -1
      LLama/LLamaWeights.cs

+ 4
- 4
LLama.Unittest/LLama.Unittest.csproj View File

@@ -11,13 +11,13 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" />
<PackageReference Include="xunit" Version="2.4.2" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.5">
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" />
<PackageReference Include="xunit" Version="2.5.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.0">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="coverlet.collector" Version="3.1.2">
<PackageReference Include="coverlet.collector" Version="6.0.0">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>


+ 48
- 5
LLama.Unittest/ModelsParamsTests.cs View File

@@ -1,4 +1,6 @@
using LLama.Common;
using System.Text;
using LLama.Common;
using Newtonsoft.Json;

namespace LLama.Unittest
{
@@ -17,12 +19,33 @@ namespace LLama.Unittest
GpuLayerCount = 111
};

var json = System.Text.Json.JsonSerializer.Serialize(expected);
var actual = System.Text.Json.JsonSerializer.Deserialize<ModelParams>(json);
var options = new System.Text.Json.JsonSerializerOptions();
options.Converters.Add(new SystemTextJsonEncodingConverter());

var json = System.Text.Json.JsonSerializer.Serialize(expected, options);
var actual = System.Text.Json.JsonSerializer.Deserialize<ModelParams>(json, options);

Assert.Equal(expected, actual);
}

private class SystemTextJsonEncodingConverter
: System.Text.Json.Serialization.JsonConverter<Encoding>

{
public override Encoding? Read(ref System.Text.Json.Utf8JsonReader reader, Type typeToConvert, System.Text.Json.JsonSerializerOptions options)
{
var name = reader.GetString();
if (name == null)
return null;
return Encoding.GetEncoding(name);
}

public override void Write(System.Text.Json.Utf8JsonWriter writer, Encoding value, System.Text.Json.JsonSerializerOptions options)
{
writer.WriteStringValue(value.WebName);
}
}

[Fact]
public void SerializeRoundTripNewtonsoft()
{
@@ -36,10 +59,30 @@ namespace LLama.Unittest
GpuLayerCount = 111
};

var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected);
var actual = Newtonsoft.Json.JsonConvert.DeserializeObject<ModelParams>(json);
var settings = new Newtonsoft.Json.JsonSerializerSettings();
settings.Converters.Add(new NewtsonsoftEncodingConverter());

var json = Newtonsoft.Json.JsonConvert.SerializeObject(expected, settings);
var actual = Newtonsoft.Json.JsonConvert.DeserializeObject<ModelParams>(json, settings);

Assert.Equal(expected, actual);
}

private class NewtsonsoftEncodingConverter
: Newtonsoft.Json.JsonConverter<Encoding>
{
public override void WriteJson(JsonWriter writer, Encoding? value, JsonSerializer serializer)
{
writer.WriteValue((string?)value?.WebName);
}

public override Encoding? ReadJson(JsonReader reader, Type objectType, Encoding? existingValue, bool hasExistingValue, JsonSerializer serializer)
{
var name = (string?)reader.Value;
if (name == null)
return null;
return Encoding.GetEncoding(name);
}
}
}
}

+ 3
- 2
LLama.Web/Common/ModelOptions.cs View File

@@ -1,4 +1,5 @@
using LLama.Abstractions;
using System.Text;
using LLama.Abstractions;

namespace LLama.Web.Common
{
@@ -115,6 +116,6 @@ namespace LLama.Web.Common
/// <summary>
/// The encoding to use for models
/// </summary>
public string Encoding { get; set; } = "UTF-8";
public Encoding Encoding { get; set; } = Encoding.UTF8;
}
}

+ 4
- 2
LLama/Abstractions/IModelParams.cs View File

@@ -1,4 +1,6 @@
namespace LLama.Abstractions
using System.Text;

namespace LLama.Abstractions
{
public interface IModelParams
{
@@ -121,6 +123,6 @@
/// <summary>
/// The encoding to use for models
/// </summary>
string Encoding { get; set; }
Encoding Encoding { get; set; }
}
}

+ 3
- 2
LLama/Common/ModelParams.cs View File

@@ -1,5 +1,6 @@
using LLama.Abstractions;
using System;
using System.Text;

namespace LLama.Common
{
@@ -114,7 +115,7 @@ namespace LLama.Common
/// <summary>
/// The encoding to use to convert text for the model
/// </summary>
public string Encoding { get; set; } = "UTF-8";
public Encoding Encoding { get; set; } = Encoding.UTF8;

/// <summary>
///
@@ -183,7 +184,7 @@ namespace LLama.Common
RopeFrequencyBase = ropeFrequencyBase;
RopeFrequencyScale = ropeFrequencyScale;
MulMatQ = mulMatQ;
Encoding = encoding;
Encoding = Encoding.GetEncoding(encoding);
}
}
}

+ 3
- 3
LLama/LLamaContext.cs View File

@@ -68,7 +68,7 @@ namespace LLama
Params = @params;

_logger = logger;
_encoding = Encoding.GetEncoding(@params.Encoding);
_encoding = @params.Encoding;

_logger?.Log(nameof(LLamaContext), $"Initializing LLama model with params: {this.Params}", ILLamaLogger.LogLevel.Info);
_ctx = Utils.InitLLamaContextFromModelParams(Params);
@@ -79,7 +79,7 @@ namespace LLama
Params = @params;

_logger = logger;
_encoding = Encoding.GetEncoding(@params.Encoding);
_encoding = @params.Encoding;
_ctx = nativeContext;
}

@@ -98,7 +98,7 @@ namespace LLama
Params = @params;

_logger = logger;
_encoding = Encoding.GetEncoding(@params.Encoding);
_encoding = @params.Encoding;

using var pin = @params.ToLlamaContextParams(out var lparams);
_ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);


+ 1
- 1
LLama/LLamaStatelessExecutor.cs View File

@@ -47,7 +47,7 @@ namespace LLama
[Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")]
public StatelessExecutor(LLamaContext context)
{
_weights = new LLamaWeights(context.NativeHandle.ModelHandle, Encoding.GetEncoding(context.Params.Encoding));
_weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding);
_params = context.Params;

Context = _weights.CreateContext(_params);


+ 1
- 1
LLama/LLamaWeights.cs View File

@@ -59,7 +59,7 @@ namespace LLama
if (!string.IsNullOrEmpty(@params.LoraAdapter))
weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);

return new LLamaWeights(weights, Encoding.GetEncoding(@params.Encoding));
return new LLamaWeights(weights, @params.Encoding);
}

/// <inheritdoc />


Loading…
Cancel
Save