|
|
|
@@ -3,6 +3,9 @@ using System.Buffers; |
|
|
|
using System.Collections; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
using System.Text.Json; |
|
|
|
using System.Text.Json.Serialization; |
|
|
|
using LLama.Common; |
|
|
|
using LLama.Native; |
|
|
|
|
|
|
|
namespace LLama.Abstractions |
|
|
|
@@ -105,6 +108,7 @@ namespace LLama.Abstractions |
|
|
|
/// <summary> |
|
|
|
/// A fixed size array to set the tensor splits across multiple GPUs |
|
|
|
/// </summary> |
|
|
|
[JsonConverter(typeof(TensorSplitsCollectionConverter))] |
|
|
|
public sealed class TensorSplitsCollection |
|
|
|
: IEnumerable<float> |
|
|
|
{ |
|
|
|
@@ -174,4 +178,24 @@ namespace LLama.Abstractions |
|
|
|
} |
|
|
|
#endregion |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// A JSON converter for <see cref="TensorSplitsCollection"/> |
|
|
|
/// </summary> |
|
|
|
public class TensorSplitsCollectionConverter |
|
|
|
: JsonConverter<TensorSplitsCollection> |
|
|
|
{ |
|
|
|
/// <inheritdoc/> |
|
|
|
public override TensorSplitsCollection? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) |
|
|
|
{ |
|
|
|
var arr = JsonSerializer.Deserialize<float[]>(ref reader, options) ?? Array.Empty<float>(); |
|
|
|
return new TensorSplitsCollection(arr); |
|
|
|
} |
|
|
|
|
|
|
|
/// <inheritdoc/> |
|
|
|
public override void Write(Utf8JsonWriter writer, TensorSplitsCollection value, JsonSerializerOptions options) |
|
|
|
{ |
|
|
|
JsonSerializer.Serialize(writer, value.Splits, options); |
|
|
|
} |
|
|
|
} |
|
|
|
} |