Browse Source

Move JSON converter for TensorSplitsCollection

tags/0.9.1
xbotter 2 years ago
parent
commit
340bbbcf48
No known key found for this signature in database GPG Key ID: A3F32F44E9F160E1
2 changed files with 24 additions and 17 deletions
  1. +24
    -0
      LLama/Abstractions/IModelParams.cs
  2. +0
    -17
      LLama/Common/ModelParams.cs

+ 24
- 0
LLama/Abstractions/IModelParams.cs View File

@@ -3,6 +3,9 @@ using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Common;
using LLama.Native;

namespace LLama.Abstractions
@@ -105,6 +108,7 @@ namespace LLama.Abstractions
/// <summary>
/// A fixed size array to set the tensor splits across multiple GPUs
/// </summary>
[JsonConverter(typeof(TensorSplitsCollectionConverter))]
public sealed class TensorSplitsCollection
: IEnumerable<float>
{
@@ -174,4 +178,24 @@ namespace LLama.Abstractions
}
#endregion
}

/// <summary>
/// A JSON converter for <see cref="TensorSplitsCollection"/>
/// </summary>
public class TensorSplitsCollectionConverter
: JsonConverter<TensorSplitsCollection>
{
/// <inheritdoc/>
public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
return new TensorSplitsCollection(arr);
}

/// <inheritdoc/>
public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
{
JsonSerializer.Serialize(writer, value.Splits, options);
}
}
}

+ 0
- 17
LLama/Common/ModelParams.cs View File

@@ -59,7 +59,6 @@ namespace LLama.Common
public bool EmbeddingMode { get; set; }

/// <inheritdoc />
[JsonConverter(typeof(TensorSplitsCollectionConverter))]
public TensorSplitsCollection TensorSplits { get; set; } = new();

/// <inheritdoc />
@@ -123,20 +122,4 @@ namespace LLama.Common
ModelPath = "";
}
}


internal class TensorSplitsCollectionConverter
: JsonConverter<TensorSplitsCollection>
{
public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>();
return new TensorSplitsCollection(arr);
}

public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options)
{
JsonSerializer.Serialize(writer, value.Splits, options);
}
}
}

Loading…
Cancel
Save