| @@ -22,5 +22,12 @@ namespace Tensorflow.Common.Extensions | |||
| { | |||
| return new Tensors(tensors); | |||
| } | |||
| public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) | |||
| { | |||
| first = values.Item1; | |||
| second = values.Item2; | |||
| third = values.Item3; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,33 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| public static class NestExtensions | |||
| { | |||
| public static Tensors ToTensors(this INestable<Tensor> tensors) | |||
| { | |||
| return new Tensors(tensors.AsNest()); | |||
| } | |||
| public static Tensors? ToTensors(this Nest<Tensor> tensors) | |||
| { | |||
| return Tensors.FromNest(tensors); | |||
| } | |||
| /// <summary> | |||
| /// If the nested object is already a nested type, this function could reduce it. | |||
| /// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. | |||
| /// </summary> | |||
| /// <typeparam name="TIn"></typeparam> | |||
| /// <typeparam name="TOut"></typeparam> | |||
| /// <param name="input"></param> | |||
| /// <returns></returns> | |||
| public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut> | |||
| { | |||
| return Nest<TOut>.ReduceFrom(input); | |||
| } | |||
| } | |||
| } | |||
| @@ -5,7 +5,7 @@ using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public class GeneralizedTensorShape: IEnumerable<long?[]> | |||
| public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?> | |||
| { | |||
| public TensorShapeConfig[] Shapes { get; set; } | |||
| /// <summary> | |||
| @@ -63,6 +63,57 @@ namespace Tensorflow.Common.Types | |||
| return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); | |||
| } | |||
| public IEnumerable<long?> Flatten() | |||
| { | |||
| List<long?> result = new List<long?>(); | |||
| foreach(var shapeConfig in Shapes) | |||
| { | |||
| result.AddRange(shapeConfig.Items); | |||
| } | |||
| return result; | |||
| } | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func) | |||
| { | |||
| List<Nest<TOut>> lists = new(); | |||
| foreach(var shapeConfig in Shapes) | |||
| { | |||
| lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x))))); | |||
| } | |||
| return new Nest<TOut>(lists); | |||
| } | |||
| public Nest<long?> AsNest() | |||
| { | |||
| Nest<long?> DealWithSingleShape(TensorShapeConfig config) | |||
| { | |||
| if (config.Items.Length == 0) | |||
| { | |||
| return Nest<long?>.Empty; | |||
| } | |||
| else if (config.Items.Length == 1) | |||
| { | |||
| return new Nest<long?>(config.Items[0]); | |||
| } | |||
| else | |||
| { | |||
| return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x))); | |||
| } | |||
| } | |||
| if(Shapes.Length == 0) | |||
| { | |||
| return Nest<long?>.Empty; | |||
| } | |||
| else if(Shapes.Length == 1) | |||
| { | |||
| return DealWithSingleShape(Shapes[0]); | |||
| } | |||
| else | |||
| { | |||
| return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s))); | |||
| } | |||
| } | |||
| public IEnumerator<long?[]> GetEnumerator() | |||
| { | |||
| foreach (var shape in Shapes) | |||
| @@ -0,0 +1,27 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// This interface indicates that a class may have a nested structure and provide | |||
| /// methods to manipulate with the structure. | |||
| /// </summary> | |||
| public interface INestStructure<T>: INestable<T> | |||
| { | |||
| /// <summary> | |||
| /// Flatten the Nestable object. Node that if the object contains only one value, | |||
| /// it will be flattened to an enumerable with one element. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| IEnumerable<T> Flatten(); | |||
| /// <summary> | |||
| /// Construct a new object with the same nested structure. | |||
| /// </summary> | |||
| /// <typeparam name="TOut"></typeparam> | |||
| /// <param name="func"></param> | |||
| /// <returns></returns> | |||
| INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func); | |||
| } | |||
| } | |||
| @@ -0,0 +1,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public interface INestable<T> | |||
| { | |||
| Nest<T> AsNest(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,62 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public static class Nest | |||
| { | |||
| /// <summary> | |||
| /// Pack the flat items to a nested sequence by the template. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="template"></param> | |||
| /// <param name="flatItems"></param> | |||
| /// <returns></returns> | |||
| public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems) | |||
| { | |||
| return template.AsNest().PackSequence(flatItems); | |||
| } | |||
| /// <summary> | |||
| /// Pack the flat items to a nested sequence by the template. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="template"></param> | |||
| /// <param name="flatItems"></param> | |||
| /// <returns></returns> | |||
| public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems) | |||
| { | |||
| return template.AsNest().PackSequence(flatItems.ToArray()); | |||
| } | |||
| /// <summary> | |||
| /// Flatten the nested object. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="nestedObject"></param> | |||
| /// <returns></returns> | |||
| public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject) | |||
| { | |||
| return nestedObject.AsNest().Flatten(); | |||
| } | |||
| /// <summary> | |||
| /// Map the structure with specified function. | |||
| /// </summary> | |||
| /// <typeparam name="TIn"></typeparam> | |||
| /// <typeparam name="TOut"></typeparam> | |||
| /// <param name="func"></param> | |||
| /// <param name="nestedObject"></param> | |||
| /// <returns></returns> | |||
| public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject) | |||
| { | |||
| return nestedObject.AsNest().MapStructure(func); | |||
| } | |||
| public static bool IsNested<T>(INestable<T> obj) | |||
| { | |||
| return obj.AsNest().IsNested(); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,458 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Extensions; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public enum NestType | |||
| { | |||
| Empty, | |||
| Node, | |||
| List, | |||
| Dictionary | |||
| } | |||
| /// <summary> | |||
| /// A nested structure which may inclulde value, list and dictionary. | |||
| /// Note that dictionary does not ensure the data order. When using it as IEnumerable, | |||
| /// its order is depth-first. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| public class Nest<T> : INestStructure<T>, IEnumerable<T> | |||
| { | |||
| private static readonly Nest<T> _empty = new Nest<T>() | |||
| { | |||
| NestType = NestType.Empty, | |||
| }; | |||
| public static Nest<T> Empty => _empty; | |||
| public NestType NestType { get; protected set; } | |||
| public string? Name { get; set; } | |||
| public T? Value { get; protected set; } | |||
| public List<Nest<T>>? ListValue { get; protected set; } | |||
| public Dictionary<string, Nest<T>>? DictValue { get; protected set; } | |||
| protected Nest() { } | |||
| public Nest(T value, string? name = null) | |||
| { | |||
| Value = value; | |||
| Name = name; | |||
| NestType = NestType.Node; | |||
| } | |||
| public Nest(IEnumerable<Nest<T>> values, string? name = null) | |||
| { | |||
| ListValue = values.ToList(); | |||
| Name = name; | |||
| NestType = NestType.List; | |||
| } | |||
| public Nest(Dictionary<string, Nest<T>> value, string? name = null) | |||
| { | |||
| DictValue = value; | |||
| Name = name; | |||
| NestType = NestType.Dictionary; | |||
| } | |||
| public Nest(Nest<T> other) | |||
| { | |||
| NestType = other.NestType; | |||
| Value = other.Value; | |||
| DictValue = other.DictValue; | |||
| ListValue = other.ListValue; | |||
| Name = other.Name; | |||
| } | |||
| public virtual IEnumerable<T> Flatten() | |||
| { | |||
| return FlattenInternal(this); | |||
| } | |||
| public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
| { | |||
| return MapStructureInternal(func); | |||
| } | |||
| /// <summary> | |||
| /// Pack the flat items to a nested sequence by the template. | |||
| /// </summary> | |||
| /// <param name="flatItems"></param> | |||
| /// <returns></returns> | |||
| public virtual Nest<T> PackSequence(T[] flatItems) | |||
| { | |||
| if(flatItems.Length == 0) | |||
| { | |||
| return Nest<T>.Empty; | |||
| } | |||
| int index = 0; | |||
| return PackSequenceInternal(this, flatItems, ref index); | |||
| } | |||
| private static Nest<T> PackSequenceInternal(Nest<T> template, T[] flatItems, ref int index) | |||
| { | |||
| if(template.NestType == NestType.Node) | |||
| { | |||
| if(index >= flatItems.Length) | |||
| { | |||
| throw new InvalidArgumentError("The template and flat items are not matched."); | |||
| } | |||
| return new Nest<T>(flatItems[index++]); | |||
| } | |||
| else if(template.NestType == NestType.List) | |||
| { | |||
| List<Nest<T>> nestedObjects = new List<Nest<T>>(); | |||
| for (int i = 0; i < template.ListValue!.Count; i++) | |||
| { | |||
| nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index)); | |||
| } | |||
| return new Nest<T>(nestedObjects); | |||
| } | |||
| else if(template.NestType == NestType.Node) | |||
| { | |||
| Dictionary<string, Nest<T>> dict = new Dictionary<string, Nest<T>>(); | |||
| foreach(var (key, value) in template.DictValue!) | |||
| { | |||
| dict[key] = PackSequenceInternal(value, flatItems, ref index); | |||
| } | |||
| return new Nest<T>(dict); | |||
| } | |||
| // Consider Empty as invalid type. | |||
| throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); | |||
| } | |||
| public virtual Nest<T> AsNest() | |||
| { | |||
| return this; | |||
| } | |||
| public virtual Nest<T> MergeWith(Nest<T>? other) | |||
| { | |||
| if(other is null || other == Nest<T>.Empty) | |||
| { | |||
| return this; | |||
| } | |||
| if(this == Nest<T>.Empty) | |||
| { | |||
| return other; | |||
| } | |||
| if(NestType == NestType.Node && other.NestType == NestType.Node) | |||
| { | |||
| return new Nest<T>(new Nest<T>[] { this, other }); | |||
| } | |||
| else if(NestType == NestType.List && other.NestType == NestType.List) | |||
| { | |||
| return new Nest<T>(this.ListValue!.Concat(other.ListValue!)); | |||
| } | |||
| else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary) | |||
| { | |||
| return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value)); | |||
| } | |||
| else | |||
| { | |||
| return new Nest<T>(new Nest<T>[] { this, other }); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not | |||
| /// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public bool IsNested() | |||
| { | |||
| if(NestType is NestType.Empty or NestType.Node) | |||
| { | |||
| return false; | |||
| } | |||
| else if(NestType is NestType.List) | |||
| { | |||
| foreach(var item in ListValue!) | |||
| { | |||
| if(item.NestType is NestType.List or NestType.Dictionary) | |||
| { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| else | |||
| { | |||
| foreach (var item in DictValue!.Values) | |||
| { | |||
| if (item.NestType is NestType.List or NestType.Dictionary) | |||
| { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } | |||
| [Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] | |||
| public T this[int index] | |||
| { | |||
| get | |||
| { | |||
| bool success = FindInternal(this, index, out var result); | |||
| if (success) | |||
| { | |||
| return result; | |||
| } | |||
| else | |||
| { | |||
| throw new IndexOutOfRangeException(); | |||
| } | |||
| } | |||
| set | |||
| { | |||
| bool success = SetInternal(this, index, value); | |||
| if (!success) | |||
| { | |||
| throw new IndexOutOfRangeException(); | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it | |||
| /// to `Nest[T]`. | |||
| /// </summary> | |||
| /// <typeparam name="TOut"></typeparam> | |||
| /// <param name="input"></param> | |||
| /// <returns></returns> | |||
| public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T> | |||
| { | |||
| var nested = input.AsNest(); | |||
| return ReduceInternal(nested); | |||
| } | |||
| private static Nest<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T> | |||
| { | |||
| if(node.NestType == NestType.Empty) | |||
| { | |||
| return Nest<T>.Empty; | |||
| } | |||
| else if(node.NestType == NestType.Node) | |||
| { | |||
| return node.Value!.AsNest(); | |||
| } | |||
| else if(node.NestType == NestType.List) | |||
| { | |||
| return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x))); | |||
| } | |||
| else // Dictionary type | |||
| { | |||
| return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value))); | |||
| } | |||
| } | |||
| private static bool FindInternal(Nest<T> node, int index, out T? result) | |||
| { | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| if(index == 0) | |||
| { | |||
| result = node.Value!; | |||
| return true; | |||
| } | |||
| result = default(T); | |||
| return false; | |||
| } | |||
| else if (node.NestType == NestType.List) | |||
| { | |||
| foreach (var item in node.ListValue!) | |||
| { | |||
| if(index == 0) | |||
| { | |||
| return FindInternal(item, index, out result); | |||
| } | |||
| index--; | |||
| } | |||
| result = default(T); | |||
| return false; | |||
| } | |||
| else if(node.NestType == NestType.Dictionary) | |||
| { | |||
| foreach (var item in node.DictValue!.Values) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| return FindInternal(item, index, out result); | |||
| } | |||
| index--; | |||
| } | |||
| result = default(T); | |||
| return false; | |||
| } | |||
| else | |||
| { | |||
| result = default(T); | |||
| return false; | |||
| } | |||
| } | |||
| private static bool SetInternal(Nest<T> node, int index, T newValue) | |||
| { | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| node.Value = newValue; | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| else if (node.NestType == NestType.List) | |||
| { | |||
| foreach (var item in node.ListValue!) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| return SetInternal(item, index, newValue); | |||
| } | |||
| index--; | |||
| } | |||
| return false; | |||
| } | |||
| else if (node.NestType == NestType.Dictionary) | |||
| { | |||
| foreach (var item in node.DictValue!.Values) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| return SetInternal(item, index, newValue); | |||
| } | |||
| index--; | |||
| } | |||
| return false; | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| private static IEnumerable<T> FlattenInternal(Nest<T> node) | |||
| { | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| yield return node.Value!; | |||
| } | |||
| else if (node.NestType == NestType.List) | |||
| { | |||
| foreach (var item in node.ListValue!) | |||
| { | |||
| foreach(var val in FlattenInternal(item)) | |||
| { | |||
| yield return val; | |||
| } | |||
| } | |||
| } | |||
| else if (node.NestType == NestType.Dictionary) | |||
| { | |||
| foreach (var item in node.DictValue!.Values) | |||
| { | |||
| foreach (var val in FlattenInternal(item)) | |||
| { | |||
| yield return val; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func) | |||
| { | |||
| if (NestType == NestType.Node) | |||
| { | |||
| return new Nest<TOut>(func(Value!)); | |||
| } | |||
| else if (NestType == NestType.List) | |||
| { | |||
| List<Nest<TOut>> outs = new List<Nest<TOut>>(); | |||
| foreach (var item in ListValue!) | |||
| { | |||
| outs.Add(item.MapStructureInternal(func)); | |||
| } | |||
| return new Nest<TOut>(outs); | |||
| } | |||
| else if (NestType == NestType.Dictionary) | |||
| { | |||
| Dictionary<string, Nest<TOut>> outs = new Dictionary<string, Nest<TOut>>(); | |||
| foreach (var (key, value) in DictValue!) | |||
| { | |||
| outs.Add(key, value.MapStructureInternal(func)); | |||
| } | |||
| return new Nest<TOut>(outs); | |||
| } | |||
| else | |||
| { | |||
| return Nest<TOut>.Empty; | |||
| } | |||
| } | |||
| public IEnumerator<T> GetEnumerator() | |||
| { | |||
| return Flatten().GetEnumerator(); | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| return GetEnumerator(); | |||
| } | |||
| public override string ToString() | |||
| { | |||
| StringBuilder sb = new StringBuilder(); | |||
| sb.Append("("); | |||
| WriteString(this, sb); | |||
| sb.Append(")"); | |||
| return sb.ToString(); | |||
| } | |||
| private static void WriteString(Nest<T> node, StringBuilder sb) | |||
| { | |||
| if (!string.IsNullOrEmpty(node.Name)) | |||
| { | |||
| sb.Append($"{node.Name}: "); | |||
| } | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| sb.Append(node.Value!.ToString()); | |||
| } | |||
| else if (node.NestType == NestType.List) | |||
| { | |||
| sb.Append("["); | |||
| for(int i = 0; i < node.ListValue!.Count; i++) | |||
| { | |||
| WriteString(node.ListValue![i], sb); | |||
| if(i != node.ListValue!.Count - 1) | |||
| { | |||
| sb.Append(", "); | |||
| } | |||
| } | |||
| sb.Append("]"); | |||
| } | |||
| else if (node.NestType == NestType.Dictionary) | |||
| { | |||
| sb.Append("{"); | |||
| int count = node.DictValue!.Count; | |||
| int i = 0; | |||
| foreach (var (key, value) in node.DictValue!) | |||
| { | |||
| sb.Append($"{key}: "); | |||
| WriteString(value, sb); | |||
| if (i != count - 1) | |||
| { | |||
| sb.Append(", "); | |||
| } | |||
| i++; | |||
| } | |||
| sb.Append("}"); | |||
| } | |||
| else | |||
| { | |||
| sb.Append("<empty>"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,99 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull | |||
| { | |||
| public IDictionary<TKey, TValue> Value { get; set; } | |||
| public NestDictionary(IDictionary<TKey, TValue> dict) | |||
| { | |||
| Value = dict; | |||
| } | |||
| public IEnumerable<TValue> Flatten() | |||
| { | |||
| return Value.Select(x => x.Value); | |||
| } | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func) | |||
| { | |||
| return new NestList<TOut>(Value.Select(x => func(x.Value))); | |||
| } | |||
| public Nest<TValue> AsNest() | |||
| { | |||
| return new Nest<TValue>(Value.Values.Select(x => new Nest<TValue>(x))); | |||
| } | |||
| // Required IDictionary<TKey, TValue> members | |||
| public int Count => Value.Count; | |||
| public bool IsReadOnly => Value.IsReadOnly; | |||
| public ICollection<TKey> Keys => Value.Keys; | |||
| public ICollection<TValue> Values => Value.Values; | |||
| public void Add(TKey key, TValue value) | |||
| { | |||
| Value.Add(key, value); | |||
| } | |||
| public void Add(KeyValuePair<TKey, TValue> item) | |||
| { | |||
| Value.Add(item); | |||
| } | |||
| public void Clear() | |||
| { | |||
| Value.Clear(); | |||
| } | |||
| public bool Contains(KeyValuePair<TKey, TValue> item) | |||
| { | |||
| return Value.Contains(item); | |||
| } | |||
| public bool ContainsKey(TKey key) | |||
| { | |||
| return Value.ContainsKey(key); | |||
| } | |||
| public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||
| { | |||
| Value.CopyTo(array, arrayIndex); | |||
| } | |||
| public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||
| { | |||
| return Value.GetEnumerator(); | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| return GetEnumerator(); | |||
| } | |||
| public bool Remove(TKey key) | |||
| { | |||
| return Value.Remove(key); | |||
| } | |||
| public bool Remove(KeyValuePair<TKey, TValue> item) | |||
| { | |||
| return Value.Remove(item); | |||
| } | |||
| public bool TryGetValue(TKey key, out TValue value) | |||
| { | |||
| return Value.TryGetValue(key, out value); | |||
| } | |||
| // Optional IDictionary<TKey, TValue> members | |||
| public TValue this[TKey key] | |||
| { | |||
| get => Value[key]; | |||
| set => Value[key] = value; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,43 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// The implementation of a list that support nest structure, in which the depth is 1. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| public sealed class NestList<T> : INestStructure<T>, IEnumerable<T> | |||
| { | |||
| public List<T> Value { get; set; } | |||
| public NestList(IEnumerable<T> values) | |||
| { | |||
| Value = new List<T>(values); | |||
| } | |||
| public IEnumerable<T> Flatten() | |||
| { | |||
| return Value; | |||
| } | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
| { | |||
| return new NestList<TOut>(Value.Select(x => func(x))); | |||
| } | |||
| public Nest<T> AsNest() | |||
| { | |||
| return new Nest<T>(Value.Select(x => new Nest<T>(x))); | |||
| } | |||
| // Enumerator implementation | |||
| public IEnumerator<T> GetEnumerator() | |||
| { | |||
| return Value.GetEnumerator(); | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| return GetEnumerator(); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,32 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// A nested structure with only one element. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| public class NestNode<T> : INestStructure<T> | |||
| { | |||
| public T Value { get; set; } | |||
| public NestNode(T value) | |||
| { | |||
| Value = value; | |||
| } | |||
| public IEnumerable<T> Flatten() | |||
| { | |||
| yield return Value; | |||
| } | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
| { | |||
| return new NestNode<TOut>(func(Value)); | |||
| } | |||
| public Nest<T> AsNest() | |||
| { | |||
| return new Nest<T>(Value); | |||
| } | |||
| } | |||
| } | |||
| @@ -3,6 +3,7 @@ using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -13,16 +14,14 @@ namespace Tensorflow | |||
| /// and Tensor[] from Tensors implicitily. | |||
| /// It works for tuple and scalar as well. | |||
| /// </summary> | |||
| public class Tensors : IEnumerable<Tensor>, IDisposable | |||
| public sealed class Tensors : Nest<Tensor>, IDisposable | |||
| { | |||
| List<Tensor> items = new List<Tensor>(); | |||
| public TF_DataType dtype => items.First().dtype; | |||
| public Shape shape => items.First().shape; | |||
| public int rank => items.First().rank; | |||
| public Graph graph => items.First().graph; | |||
| public TF_DataType dtype => this.First().dtype; | |||
| public Shape shape => this.First().shape; | |||
| public int rank => this.First().rank; | |||
| public Graph graph => this.First().graph; | |||
| public bool IsList { get; set; } | |||
| public int Length => items.Count(); | |||
| public int Length => this.Count(); | |||
| /// <summary> | |||
| /// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception. | |||
| /// </summary> | |||
| @@ -35,7 +34,7 @@ namespace Tensorflow | |||
| throw new ValueError("Tensors with more than one tensor cannot be " + | |||
| "implicitly converted to Tensor."); | |||
| } | |||
| return items.First(); | |||
| return this.First(); | |||
| } | |||
| } | |||
| @@ -52,150 +51,194 @@ namespace Tensorflow | |||
| throw new ValueError($"Tensors with {Length} tensor cannot be " + | |||
| "implicitly converted to Tensor."); | |||
| } | |||
| return items.FirstOrDefault(); | |||
| return this.FirstOrDefault(); | |||
| } | |||
| } | |||
| public Tensor this[int index] | |||
| public Tensor this[params string[] slices] | |||
| => this.First()[slices]; | |||
| public Tensors(Tensor tensor) : base(tensor) | |||
| { | |||
| get => items[index]; | |||
| set => items[index] = value; | |||
| } | |||
| public Tensor this[params string[] slices] | |||
| => items.First()[slices]; | |||
| public Tensors(params Tensor[] tensors) | |||
| private Tensors(Nest<Tensor> nested) : base(nested) | |||
| { | |||
| items.AddRange(tensors); | |||
| } | |||
| public Tensors(IEnumerable<Tensor> tensors) | |||
| public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest<Tensor>(x))) | |||
| { | |||
| items.AddRange(tensors); | |||
| } | |||
| public Tensors(NDArray nd) | |||
| public Tensors(IEnumerable<Tensor> tensors): base(tensors.Select(x => new Nest<Tensor>(x))) | |||
| { | |||
| items.Add(ops.convert_to_tensor(nd)); | |||
| } | |||
| public IEnumerator<Tensor> GetEnumerator() | |||
| public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) | |||
| { | |||
| foreach (var tensor in items) | |||
| yield return tensor; | |||
| } | |||
| public bool IsSingle() | |||
| { | |||
| return Length == 1; | |||
| } | |||
| public new Tensors MergeWith(Nest<Tensor>? other) | |||
| { | |||
| return FromNest(base.MergeWith(other)); | |||
| } | |||
| [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + | |||
| "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] | |||
| public void Add(Tensor tensor) | |||
| => items.Add(tensor); | |||
| { | |||
| if(NestType == NestType.Dictionary) | |||
| { | |||
| throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||
| } | |||
| else if(NestType == NestType.Node) | |||
| { | |||
| NestType = NestType.List; | |||
| ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) }; | |||
| Value = null; | |||
| } | |||
| else | |||
| { | |||
| ListValue.Add(new Nest<Tensor>(tensor)); | |||
| } | |||
| } | |||
| [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + | |||
| "some tensors to `Tensors`, creating a new instance with your newly added tensors is a better choice.")] | |||
| public void AddRange(IEnumerable<Tensor> tensors) | |||
| => items.AddRange(tensors); | |||
| { | |||
| if (NestType == NestType.Dictionary) | |||
| { | |||
| throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||
| } | |||
| else if (NestType == NestType.Node) | |||
| { | |||
| NestType = NestType.List; | |||
| ListValue = new() { new Nest<Tensor>(Value) }; | |||
| ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||
| Value = null; | |||
| } | |||
| else | |||
| { | |||
| ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||
| } | |||
| } | |||
| [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to insert " + | |||
| "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] | |||
| public void Insert(int index, Tensor tensor) | |||
| => items.Insert(index, tensor); | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| => GetEnumerator(); | |||
| { | |||
| if (NestType == NestType.List) | |||
| { | |||
| ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||
| } | |||
| else if(NestType == NestType.Node) | |||
| { | |||
| NestType = NestType.List; | |||
| ListValue = new() { new Nest<Tensor>(Value) }; | |||
| ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||
| Value = null; | |||
| } | |||
| else | |||
| { | |||
| throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||
| } | |||
| } | |||
| public string[] StringData() | |||
| { | |||
| EnsureSingleTensor(this, "nnumpy"); | |||
| return this[0].StringData(); | |||
| return Single.StringData(); | |||
| } | |||
| public string StringData(int index) | |||
| { | |||
| EnsureSingleTensor(this, "nnumpy"); | |||
| return this[0].StringData(index); | |||
| return Single.StringData(index); | |||
| } | |||
| public NDArray numpy() | |||
| { | |||
| EnsureSingleTensor(this, "nnumpy"); | |||
| return this[0].numpy(); | |||
| return Single.numpy(); | |||
| } | |||
| [Obsolete] | |||
| public T[] ToArray<T>() where T: unmanaged | |||
| { | |||
| EnsureSingleTensor(this, $"ToArray<{typeof(T)}>"); | |||
| return this[0].ToArray<T>(); | |||
| return Single.ToArray<T>(); | |||
| } | |||
| #region Explicit Conversions | |||
| public unsafe static explicit operator bool(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to bool"); | |||
| return (bool)tensor[0]; | |||
| return (bool)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator sbyte(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to sbyte"); | |||
| return (sbyte)tensor[0]; | |||
| return (sbyte)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator byte(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||
| return (byte)tensor[0]; | |||
| return (byte)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator ushort(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to ushort"); | |||
| return (ushort)tensor[0]; | |||
| return (ushort)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator short(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to short"); | |||
| return (short)tensor[0]; | |||
| return (short)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator int(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to int"); | |||
| return (int)tensor[0]; | |||
| return (int)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator uint(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to uint"); | |||
| return (uint)tensor[0]; | |||
| return (uint)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator long(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to long"); | |||
| return (long)tensor[0]; | |||
| return (long)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator ulong(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to ulong"); | |||
| return (ulong)tensor[0]; | |||
| return (ulong)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator float(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||
| return (byte)tensor[0]; | |||
| return (byte)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator double(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to double"); | |||
| return (double)tensor[0]; | |||
| return (double)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator string(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to string"); | |||
| return (string)tensor[0]; | |||
| return (string)tensor.Single; | |||
| } | |||
| public static explicit operator object[](Tensors tensors) | |||
| => tensors.items.ToArray(); | |||
| => tensors.Flatten().ToArray(); | |||
| #endregion | |||
| #region Implicit Conversions | |||
| @@ -219,52 +262,40 @@ namespace Tensorflow | |||
| => tensors?.SingleOrNull; | |||
| public static implicit operator Tensor[](Tensors tensors) | |||
| => tensors.items.ToArray(); | |||
| => tensors.Flatten().ToArray(); | |||
| #endregion | |||
| public void Deconstruct(out Tensor a, out Tensors? b) | |||
| public static Tensors? FromNest(Nest<Tensor> nested) | |||
| { | |||
| a = items[0]; | |||
| b = Length == 1? null : new Tensors(items.Skip(1)); | |||
| if(nested == Nest<Tensor>.Empty) | |||
| { | |||
| return null; | |||
| } | |||
| return new Tensors(nested); | |||
| } | |||
| private static void EnsureSingleTensor(Tensors tensors, string methodnName) | |||
| public void Deconstruct(out Tensor a, out Tensors? b) | |||
| { | |||
| if(tensors.Length == 0) | |||
| { | |||
| throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor."); | |||
| } | |||
| else if(tensors.Length > 1) | |||
| { | |||
| throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor."); | |||
| } | |||
| a = this.First(); | |||
| b = Length == 1? null : new Tensors(this.Skip(1)); | |||
| } | |||
| public override string ToString() | |||
| { | |||
| if(items.Count == 1) | |||
| if(Length == 1) | |||
| { | |||
| return items[0].ToString(); | |||
| return this.First().ToString(); | |||
| } | |||
| else | |||
| { | |||
| StringBuilder sb = new StringBuilder(); | |||
| sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n"); | |||
| for(int i = 0; i < items.Count; i++) | |||
| { | |||
| var tensor = items[i]; | |||
| sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n"); | |||
| } | |||
| sb.Append("]\n"); | |||
| return sb.ToString(); | |||
| return $"Totally {Length} tensors: {base.ToString()}"; | |||
| } | |||
| } | |||
| public void Dispose() | |||
| { | |||
| foreach (var item in items) | |||
| item.Dispose(); | |||
| foreach (var tensor in this) | |||
| tensor.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -36,6 +36,7 @@ namespace Tensorflow.Util | |||
| // (np.array([3, 4]), tf.constant([3, 4])))` | |||
| // | |||
| [Obsolete] | |||
| public static class nest | |||
| { | |||
| @@ -170,39 +171,6 @@ namespace Tensorflow.Util | |||
| throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | |||
| } | |||
| public static bool is_nested(object obj) | |||
| { | |||
| // Refer to https://www.tensorflow.org/api_docs/python/tf/nest | |||
| //if (obj is IList || obj is IDictionary || obj is ITuple) | |||
| // return true; | |||
| if (obj is IList || obj is IDictionary) | |||
| return true; | |||
| if (obj is NDArray || obj is Tensor || obj is string || obj.GetType().IsGenericType | |||
| || obj is ISet<int> || obj is ISet<float> || obj is ISet<double>) | |||
| return false; | |||
| if (obj.GetType().IsNested) return true; | |||
| // Check if the object is an IEnumerable | |||
| if (obj is IEnumerable) | |||
| { | |||
| // If it is, check if it is a nested structure | |||
| foreach (object item in (IEnumerable)obj) | |||
| { | |||
| if (is_nested(item)) | |||
| { | |||
| return true; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| else | |||
| { | |||
| // If it is not, return false | |||
| return false; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Yields the next value from the given iterable. | |||
| /// </summary> | |||