| @@ -22,5 +22,12 @@ namespace Tensorflow.Common.Extensions | |||||
| { | { | ||||
| return new Tensors(tensors); | return new Tensors(tensors); | ||||
| } | } | ||||
| public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) | |||||
| { | |||||
| first = values.Item1; | |||||
| second = values.Item2; | |||||
| third = values.Item3; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,33 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Common.Types; | |||||
| namespace Tensorflow.Common.Extensions | |||||
| { | |||||
| public static class NestExtensions | |||||
| { | |||||
| public static Tensors ToTensors(this INestable<Tensor> tensors) | |||||
| { | |||||
| return new Tensors(tensors.AsNest()); | |||||
| } | |||||
| public static Tensors? ToTensors(this Nest<Tensor> tensors) | |||||
| { | |||||
| return Tensors.FromNest(tensors); | |||||
| } | |||||
| /// <summary> | |||||
| /// If the nested object is already a nested type, this function could reduce it. | |||||
| /// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. | |||||
| /// </summary> | |||||
| /// <typeparam name="TIn"></typeparam> | |||||
| /// <typeparam name="TOut"></typeparam> | |||||
| /// <param name="input"></param> | |||||
| /// <returns></returns> | |||||
| public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut> | |||||
| { | |||||
| return Nest<TOut>.ReduceFrom(input); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -5,7 +5,7 @@ using System.Text; | |||||
| namespace Tensorflow.Common.Types | namespace Tensorflow.Common.Types | ||||
| { | { | ||||
| public class GeneralizedTensorShape: IEnumerable<long?[]> | |||||
| public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?> | |||||
| { | { | ||||
| public TensorShapeConfig[] Shapes { get; set; } | public TensorShapeConfig[] Shapes { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -63,6 +63,57 @@ namespace Tensorflow.Common.Types | |||||
| return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); | return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); | ||||
| } | } | ||||
| public IEnumerable<long?> Flatten() | |||||
| { | |||||
| List<long?> result = new List<long?>(); | |||||
| foreach(var shapeConfig in Shapes) | |||||
| { | |||||
| result.AddRange(shapeConfig.Items); | |||||
| } | |||||
| return result; | |||||
| } | |||||
| public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func) | |||||
| { | |||||
| List<Nest<TOut>> lists = new(); | |||||
| foreach(var shapeConfig in Shapes) | |||||
| { | |||||
| lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x))))); | |||||
| } | |||||
| return new Nest<TOut>(lists); | |||||
| } | |||||
| public Nest<long?> AsNest() | |||||
| { | |||||
| Nest<long?> DealWithSingleShape(TensorShapeConfig config) | |||||
| { | |||||
| if (config.Items.Length == 0) | |||||
| { | |||||
| return Nest<long?>.Empty; | |||||
| } | |||||
| else if (config.Items.Length == 1) | |||||
| { | |||||
| return new Nest<long?>(config.Items[0]); | |||||
| } | |||||
| else | |||||
| { | |||||
| return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x))); | |||||
| } | |||||
| } | |||||
| if(Shapes.Length == 0) | |||||
| { | |||||
| return Nest<long?>.Empty; | |||||
| } | |||||
| else if(Shapes.Length == 1) | |||||
| { | |||||
| return DealWithSingleShape(Shapes[0]); | |||||
| } | |||||
| else | |||||
| { | |||||
| return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s))); | |||||
| } | |||||
| } | |||||
| public IEnumerator<long?[]> GetEnumerator() | public IEnumerator<long?[]> GetEnumerator() | ||||
| { | { | ||||
| foreach (var shape in Shapes) | foreach (var shape in Shapes) | ||||
| @@ -0,0 +1,27 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| /// <summary> | |||||
| /// This interface indicates that a class may have a nested structure and provide | |||||
| /// methods to manipulate with the structure. | |||||
| /// </summary> | |||||
| public interface INestStructure<T>: INestable<T> | |||||
| { | |||||
| /// <summary> | |||||
| /// Flatten the Nestable object. Node that if the object contains only one value, | |||||
| /// it will be flattened to an enumerable with one element. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| IEnumerable<T> Flatten(); | |||||
| /// <summary> | |||||
| /// Construct a new object with the same nested structure. | |||||
| /// </summary> | |||||
| /// <typeparam name="TOut"></typeparam> | |||||
| /// <param name="func"></param> | |||||
| /// <returns></returns> | |||||
| INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,11 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| public interface INestable<T> | |||||
| { | |||||
| Nest<T> AsNest(); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,62 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| public static class Nest | |||||
| { | |||||
| /// <summary> | |||||
| /// Pack the flat items to a nested sequence by the template. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="template"></param> | |||||
| /// <param name="flatItems"></param> | |||||
| /// <returns></returns> | |||||
| public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems) | |||||
| { | |||||
| return template.AsNest().PackSequence(flatItems); | |||||
| } | |||||
| /// <summary> | |||||
| /// Pack the flat items to a nested sequence by the template. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="template"></param> | |||||
| /// <param name="flatItems"></param> | |||||
| /// <returns></returns> | |||||
| public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems) | |||||
| { | |||||
| return template.AsNest().PackSequence(flatItems.ToArray()); | |||||
| } | |||||
| /// <summary> | |||||
| /// Flatten the nested object. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="nestedObject"></param> | |||||
| /// <returns></returns> | |||||
| public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject) | |||||
| { | |||||
| return nestedObject.AsNest().Flatten(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Map the structure with specified function. | |||||
| /// </summary> | |||||
| /// <typeparam name="TIn"></typeparam> | |||||
| /// <typeparam name="TOut"></typeparam> | |||||
| /// <param name="func"></param> | |||||
| /// <param name="nestedObject"></param> | |||||
| /// <returns></returns> | |||||
| public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject) | |||||
| { | |||||
| return nestedObject.AsNest().MapStructure(func); | |||||
| } | |||||
| public static bool IsNested<T>(INestable<T> obj) | |||||
| { | |||||
| return obj.AsNest().IsNested(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,458 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Common.Extensions; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| public enum NestType | |||||
| { | |||||
| Empty, | |||||
| Node, | |||||
| List, | |||||
| Dictionary | |||||
| } | |||||
| /// <summary> | |||||
| /// A nested structure which may inclulde value, list and dictionary. | |||||
| /// Note that dictionary does not ensure the data order. When using it as IEnumerable, | |||||
| /// its order is depth-first. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| public class Nest<T> : INestStructure<T>, IEnumerable<T> | |||||
| { | |||||
| private static readonly Nest<T> _empty = new Nest<T>() | |||||
| { | |||||
| NestType = NestType.Empty, | |||||
| }; | |||||
| public static Nest<T> Empty => _empty; | |||||
| public NestType NestType { get; protected set; } | |||||
| public string? Name { get; set; } | |||||
| public T? Value { get; protected set; } | |||||
| public List<Nest<T>>? ListValue { get; protected set; } | |||||
| public Dictionary<string, Nest<T>>? DictValue { get; protected set; } | |||||
| protected Nest() { } | |||||
| public Nest(T value, string? name = null) | |||||
| { | |||||
| Value = value; | |||||
| Name = name; | |||||
| NestType = NestType.Node; | |||||
| } | |||||
| public Nest(IEnumerable<Nest<T>> values, string? name = null) | |||||
| { | |||||
| ListValue = values.ToList(); | |||||
| Name = name; | |||||
| NestType = NestType.List; | |||||
| } | |||||
| public Nest(Dictionary<string, Nest<T>> value, string? name = null) | |||||
| { | |||||
| DictValue = value; | |||||
| Name = name; | |||||
| NestType = NestType.Dictionary; | |||||
| } | |||||
| public Nest(Nest<T> other) | |||||
| { | |||||
| NestType = other.NestType; | |||||
| Value = other.Value; | |||||
| DictValue = other.DictValue; | |||||
| ListValue = other.ListValue; | |||||
| Name = other.Name; | |||||
| } | |||||
| public virtual IEnumerable<T> Flatten() | |||||
| { | |||||
| return FlattenInternal(this); | |||||
| } | |||||
| public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||||
| { | |||||
| return MapStructureInternal(func); | |||||
| } | |||||
| /// <summary> | |||||
| /// Pack the flat items to a nested sequence by the template. | |||||
| /// </summary> | |||||
| /// <param name="flatItems"></param> | |||||
| /// <returns></returns> | |||||
| public virtual Nest<T> PackSequence(T[] flatItems) | |||||
| { | |||||
| if(flatItems.Length == 0) | |||||
| { | |||||
| return Nest<T>.Empty; | |||||
| } | |||||
| int index = 0; | |||||
| return PackSequenceInternal(this, flatItems, ref index); | |||||
| } | |||||
| private static Nest<T> PackSequenceInternal(Nest<T> template, T[] flatItems, ref int index) | |||||
| { | |||||
| if(template.NestType == NestType.Node) | |||||
| { | |||||
| if(index >= flatItems.Length) | |||||
| { | |||||
| throw new InvalidArgumentError("The template and flat items are not matched."); | |||||
| } | |||||
| return new Nest<T>(flatItems[index++]); | |||||
| } | |||||
| else if(template.NestType == NestType.List) | |||||
| { | |||||
| List<Nest<T>> nestedObjects = new List<Nest<T>>(); | |||||
| for (int i = 0; i < template.ListValue!.Count; i++) | |||||
| { | |||||
| nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index)); | |||||
| } | |||||
| return new Nest<T>(nestedObjects); | |||||
| } | |||||
| else if(template.NestType == NestType.Node) | |||||
| { | |||||
| Dictionary<string, Nest<T>> dict = new Dictionary<string, Nest<T>>(); | |||||
| foreach(var (key, value) in template.DictValue!) | |||||
| { | |||||
| dict[key] = PackSequenceInternal(value, flatItems, ref index); | |||||
| } | |||||
| return new Nest<T>(dict); | |||||
| } | |||||
| // Consider Empty as invalid type. | |||||
| throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); | |||||
| } | |||||
| public virtual Nest<T> AsNest() | |||||
| { | |||||
| return this; | |||||
| } | |||||
| public virtual Nest<T> MergeWith(Nest<T>? other) | |||||
| { | |||||
| if(other is null || other == Nest<T>.Empty) | |||||
| { | |||||
| return this; | |||||
| } | |||||
| if(this == Nest<T>.Empty) | |||||
| { | |||||
| return other; | |||||
| } | |||||
| if(NestType == NestType.Node && other.NestType == NestType.Node) | |||||
| { | |||||
| return new Nest<T>(new Nest<T>[] { this, other }); | |||||
| } | |||||
| else if(NestType == NestType.List && other.NestType == NestType.List) | |||||
| { | |||||
| return new Nest<T>(this.ListValue!.Concat(other.ListValue!)); | |||||
| } | |||||
| else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary) | |||||
| { | |||||
| return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value)); | |||||
| } | |||||
| else | |||||
| { | |||||
| return new Nest<T>(new Nest<T>[] { this, other }); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not | |||||
| /// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public bool IsNested() | |||||
| { | |||||
| if(NestType is NestType.Empty or NestType.Node) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| else if(NestType is NestType.List) | |||||
| { | |||||
| foreach(var item in ListValue!) | |||||
| { | |||||
| if(item.NestType is NestType.List or NestType.Dictionary) | |||||
| { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| else | |||||
| { | |||||
| foreach (var item in DictValue!.Values) | |||||
| { | |||||
| if (item.NestType is NestType.List or NestType.Dictionary) | |||||
| { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } | |||||
| [Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] | |||||
| public T this[int index] | |||||
| { | |||||
| get | |||||
| { | |||||
| bool success = FindInternal(this, index, out var result); | |||||
| if (success) | |||||
| { | |||||
| return result; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new IndexOutOfRangeException(); | |||||
| } | |||||
| } | |||||
| set | |||||
| { | |||||
| bool success = SetInternal(this, index, value); | |||||
| if (!success) | |||||
| { | |||||
| throw new IndexOutOfRangeException(); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it | |||||
| /// to `Nest[T]`. | |||||
| /// </summary> | |||||
| /// <typeparam name="TOut"></typeparam> | |||||
| /// <param name="input"></param> | |||||
| /// <returns></returns> | |||||
| public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T> | |||||
| { | |||||
| var nested = input.AsNest(); | |||||
| return ReduceInternal(nested); | |||||
| } | |||||
| private static Nest<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T> | |||||
| { | |||||
| if(node.NestType == NestType.Empty) | |||||
| { | |||||
| return Nest<T>.Empty; | |||||
| } | |||||
| else if(node.NestType == NestType.Node) | |||||
| { | |||||
| return node.Value!.AsNest(); | |||||
| } | |||||
| else if(node.NestType == NestType.List) | |||||
| { | |||||
| return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x))); | |||||
| } | |||||
| else // Dictionary type | |||||
| { | |||||
| return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value))); | |||||
| } | |||||
| } | |||||
| private static bool FindInternal(Nest<T> node, int index, out T? result) | |||||
| { | |||||
| if (node.NestType == NestType.Node) | |||||
| { | |||||
| if(index == 0) | |||||
| { | |||||
| result = node.Value!; | |||||
| return true; | |||||
| } | |||||
| result = default(T); | |||||
| return false; | |||||
| } | |||||
| else if (node.NestType == NestType.List) | |||||
| { | |||||
| foreach (var item in node.ListValue!) | |||||
| { | |||||
| if(index == 0) | |||||
| { | |||||
| return FindInternal(item, index, out result); | |||||
| } | |||||
| index--; | |||||
| } | |||||
| result = default(T); | |||||
| return false; | |||||
| } | |||||
| else if(node.NestType == NestType.Dictionary) | |||||
| { | |||||
| foreach (var item in node.DictValue!.Values) | |||||
| { | |||||
| if (index == 0) | |||||
| { | |||||
| return FindInternal(item, index, out result); | |||||
| } | |||||
| index--; | |||||
| } | |||||
| result = default(T); | |||||
| return false; | |||||
| } | |||||
| else | |||||
| { | |||||
| result = default(T); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| private static bool SetInternal(Nest<T> node, int index, T newValue) | |||||
| { | |||||
| if (node.NestType == NestType.Node) | |||||
| { | |||||
| if (index == 0) | |||||
| { | |||||
| node.Value = newValue; | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| else if (node.NestType == NestType.List) | |||||
| { | |||||
| foreach (var item in node.ListValue!) | |||||
| { | |||||
| if (index == 0) | |||||
| { | |||||
| return SetInternal(item, index, newValue); | |||||
| } | |||||
| index--; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| else if (node.NestType == NestType.Dictionary) | |||||
| { | |||||
| foreach (var item in node.DictValue!.Values) | |||||
| { | |||||
| if (index == 0) | |||||
| { | |||||
| return SetInternal(item, index, newValue); | |||||
| } | |||||
| index--; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| else | |||||
| { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| private static IEnumerable<T> FlattenInternal(Nest<T> node) | |||||
| { | |||||
| if (node.NestType == NestType.Node) | |||||
| { | |||||
| yield return node.Value!; | |||||
| } | |||||
| else if (node.NestType == NestType.List) | |||||
| { | |||||
| foreach (var item in node.ListValue!) | |||||
| { | |||||
| foreach(var val in FlattenInternal(item)) | |||||
| { | |||||
| yield return val; | |||||
| } | |||||
| } | |||||
| } | |||||
| else if (node.NestType == NestType.Dictionary) | |||||
| { | |||||
| foreach (var item in node.DictValue!.Values) | |||||
| { | |||||
| foreach (var val in FlattenInternal(item)) | |||||
| { | |||||
| yield return val; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func) | |||||
| { | |||||
| if (NestType == NestType.Node) | |||||
| { | |||||
| return new Nest<TOut>(func(Value!)); | |||||
| } | |||||
| else if (NestType == NestType.List) | |||||
| { | |||||
| List<Nest<TOut>> outs = new List<Nest<TOut>>(); | |||||
| foreach (var item in ListValue!) | |||||
| { | |||||
| outs.Add(item.MapStructureInternal(func)); | |||||
| } | |||||
| return new Nest<TOut>(outs); | |||||
| } | |||||
| else if (NestType == NestType.Dictionary) | |||||
| { | |||||
| Dictionary<string, Nest<TOut>> outs = new Dictionary<string, Nest<TOut>>(); | |||||
| foreach (var (key, value) in DictValue!) | |||||
| { | |||||
| outs.Add(key, value.MapStructureInternal(func)); | |||||
| } | |||||
| return new Nest<TOut>(outs); | |||||
| } | |||||
| else | |||||
| { | |||||
| return Nest<TOut>.Empty; | |||||
| } | |||||
| } | |||||
| public IEnumerator<T> GetEnumerator() | |||||
| { | |||||
| return Flatten().GetEnumerator(); | |||||
| } | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| { | |||||
| return GetEnumerator(); | |||||
| } | |||||
| public override string ToString() | |||||
| { | |||||
| StringBuilder sb = new StringBuilder(); | |||||
| sb.Append("("); | |||||
| WriteString(this, sb); | |||||
| sb.Append(")"); | |||||
| return sb.ToString(); | |||||
| } | |||||
| private static void WriteString(Nest<T> node, StringBuilder sb) | |||||
| { | |||||
| if (!string.IsNullOrEmpty(node.Name)) | |||||
| { | |||||
| sb.Append($"{node.Name}: "); | |||||
| } | |||||
| if (node.NestType == NestType.Node) | |||||
| { | |||||
| sb.Append(node.Value!.ToString()); | |||||
| } | |||||
| else if (node.NestType == NestType.List) | |||||
| { | |||||
| sb.Append("["); | |||||
| for(int i = 0; i < node.ListValue!.Count; i++) | |||||
| { | |||||
| WriteString(node.ListValue![i], sb); | |||||
| if(i != node.ListValue!.Count - 1) | |||||
| { | |||||
| sb.Append(", "); | |||||
| } | |||||
| } | |||||
| sb.Append("]"); | |||||
| } | |||||
| else if (node.NestType == NestType.Dictionary) | |||||
| { | |||||
| sb.Append("{"); | |||||
| int count = node.DictValue!.Count; | |||||
| int i = 0; | |||||
| foreach (var (key, value) in node.DictValue!) | |||||
| { | |||||
| sb.Append($"{key}: "); | |||||
| WriteString(value, sb); | |||||
| if (i != count - 1) | |||||
| { | |||||
| sb.Append(", "); | |||||
| } | |||||
| i++; | |||||
| } | |||||
| sb.Append("}"); | |||||
| } | |||||
| else | |||||
| { | |||||
| sb.Append("<empty>"); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,99 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull | |||||
| { | |||||
| public IDictionary<TKey, TValue> Value { get; set; } | |||||
| public NestDictionary(IDictionary<TKey, TValue> dict) | |||||
| { | |||||
| Value = dict; | |||||
| } | |||||
| public IEnumerable<TValue> Flatten() | |||||
| { | |||||
| return Value.Select(x => x.Value); | |||||
| } | |||||
| public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func) | |||||
| { | |||||
| return new NestList<TOut>(Value.Select(x => func(x.Value))); | |||||
| } | |||||
| public Nest<TValue> AsNest() | |||||
| { | |||||
| return new Nest<TValue>(Value.Values.Select(x => new Nest<TValue>(x))); | |||||
| } | |||||
| // Required IDictionary<TKey, TValue> members | |||||
| public int Count => Value.Count; | |||||
| public bool IsReadOnly => Value.IsReadOnly; | |||||
| public ICollection<TKey> Keys => Value.Keys; | |||||
| public ICollection<TValue> Values => Value.Values; | |||||
| public void Add(TKey key, TValue value) | |||||
| { | |||||
| Value.Add(key, value); | |||||
| } | |||||
| public void Add(KeyValuePair<TKey, TValue> item) | |||||
| { | |||||
| Value.Add(item); | |||||
| } | |||||
| public void Clear() | |||||
| { | |||||
| Value.Clear(); | |||||
| } | |||||
| public bool Contains(KeyValuePair<TKey, TValue> item) | |||||
| { | |||||
| return Value.Contains(item); | |||||
| } | |||||
| public bool ContainsKey(TKey key) | |||||
| { | |||||
| return Value.ContainsKey(key); | |||||
| } | |||||
| public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||||
| { | |||||
| Value.CopyTo(array, arrayIndex); | |||||
| } | |||||
| public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||||
| { | |||||
| return Value.GetEnumerator(); | |||||
| } | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| { | |||||
| return GetEnumerator(); | |||||
| } | |||||
| public bool Remove(TKey key) | |||||
| { | |||||
| return Value.Remove(key); | |||||
| } | |||||
| public bool Remove(KeyValuePair<TKey, TValue> item) | |||||
| { | |||||
| return Value.Remove(item); | |||||
| } | |||||
| public bool TryGetValue(TKey key, out TValue value) | |||||
| { | |||||
| return Value.TryGetValue(key, out value); | |||||
| } | |||||
| // Optional IDictionary<TKey, TValue> members | |||||
| public TValue this[TKey key] | |||||
| { | |||||
| get => Value[key]; | |||||
| set => Value[key] = value; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,43 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| /// <summary> | |||||
| /// The implementation of a list that support nest structure, in which the depth is 1. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| public sealed class NestList<T> : INestStructure<T>, IEnumerable<T> | |||||
| { | |||||
| public List<T> Value { get; set; } | |||||
| public NestList(IEnumerable<T> values) | |||||
| { | |||||
| Value = new List<T>(values); | |||||
| } | |||||
| public IEnumerable<T> Flatten() | |||||
| { | |||||
| return Value; | |||||
| } | |||||
| public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||||
| { | |||||
| return new NestList<TOut>(Value.Select(x => func(x))); | |||||
| } | |||||
| public Nest<T> AsNest() | |||||
| { | |||||
| return new Nest<T>(Value.Select(x => new Nest<T>(x))); | |||||
| } | |||||
| // Enumerator implementation | |||||
| public IEnumerator<T> GetEnumerator() | |||||
| { | |||||
| return Value.GetEnumerator(); | |||||
| } | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| { | |||||
| return GetEnumerator(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,32 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Common.Types | |||||
| { | |||||
| /// <summary> | |||||
| /// A nested structure with only one element. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| public class NestNode<T> : INestStructure<T> | |||||
| { | |||||
| public T Value { get; set; } | |||||
| public NestNode(T value) | |||||
| { | |||||
| Value = value; | |||||
| } | |||||
| public IEnumerable<T> Flatten() | |||||
| { | |||||
| yield return Value; | |||||
| } | |||||
| public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||||
| { | |||||
| return new NestNode<TOut>(func(Value)); | |||||
| } | |||||
| public Nest<T> AsNest() | |||||
| { | |||||
| return new Nest<T>(Value); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -3,6 +3,7 @@ using System; | |||||
| using System.Collections; | using System.Collections; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Common.Types; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -13,16 +14,14 @@ namespace Tensorflow | |||||
| /// and Tensor[] from Tensors implicitily. | /// and Tensor[] from Tensors implicitily. | ||||
| /// It works for tuple and scalar as well. | /// It works for tuple and scalar as well. | ||||
| /// </summary> | /// </summary> | ||||
| public class Tensors : IEnumerable<Tensor>, IDisposable | |||||
| public sealed class Tensors : Nest<Tensor>, IDisposable | |||||
| { | { | ||||
| List<Tensor> items = new List<Tensor>(); | |||||
| public TF_DataType dtype => items.First().dtype; | |||||
| public Shape shape => items.First().shape; | |||||
| public int rank => items.First().rank; | |||||
| public Graph graph => items.First().graph; | |||||
| public TF_DataType dtype => this.First().dtype; | |||||
| public Shape shape => this.First().shape; | |||||
| public int rank => this.First().rank; | |||||
| public Graph graph => this.First().graph; | |||||
| public bool IsList { get; set; } | public bool IsList { get; set; } | ||||
| public int Length => items.Count(); | |||||
| public int Length => this.Count(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception. | /// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -35,7 +34,7 @@ namespace Tensorflow | |||||
| throw new ValueError("Tensors with more than one tensor cannot be " + | throw new ValueError("Tensors with more than one tensor cannot be " + | ||||
| "implicitly converted to Tensor."); | "implicitly converted to Tensor."); | ||||
| } | } | ||||
| return items.First(); | |||||
| return this.First(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -52,150 +51,194 @@ namespace Tensorflow | |||||
| throw new ValueError($"Tensors with {Length} tensor cannot be " + | throw new ValueError($"Tensors with {Length} tensor cannot be " + | ||||
| "implicitly converted to Tensor."); | "implicitly converted to Tensor."); | ||||
| } | } | ||||
| return items.FirstOrDefault(); | |||||
| return this.FirstOrDefault(); | |||||
| } | } | ||||
| } | } | ||||
| public Tensor this[int index] | |||||
| public Tensor this[params string[] slices] | |||||
| => this.First()[slices]; | |||||
| public Tensors(Tensor tensor) : base(tensor) | |||||
| { | { | ||||
| get => items[index]; | |||||
| set => items[index] = value; | |||||
| } | } | ||||
| public Tensor this[params string[] slices] | |||||
| => items.First()[slices]; | |||||
| public Tensors(params Tensor[] tensors) | |||||
| private Tensors(Nest<Tensor> nested) : base(nested) | |||||
| { | { | ||||
| items.AddRange(tensors); | |||||
| } | } | ||||
| public Tensors(IEnumerable<Tensor> tensors) | |||||
| public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest<Tensor>(x))) | |||||
| { | { | ||||
| items.AddRange(tensors); | |||||
| } | } | ||||
| public Tensors(NDArray nd) | |||||
| public Tensors(IEnumerable<Tensor> tensors): base(tensors.Select(x => new Nest<Tensor>(x))) | |||||
| { | { | ||||
| items.Add(ops.convert_to_tensor(nd)); | |||||
| } | } | ||||
| public IEnumerator<Tensor> GetEnumerator() | |||||
| public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) | |||||
| { | { | ||||
| foreach (var tensor in items) | |||||
| yield return tensor; | |||||
| } | } | ||||
| public bool IsSingle() | |||||
| { | |||||
| return Length == 1; | |||||
| } | |||||
| public new Tensors MergeWith(Nest<Tensor>? other) | |||||
| { | |||||
| return FromNest(base.MergeWith(other)); | |||||
| } | |||||
| [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + | |||||
| "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] | |||||
| public void Add(Tensor tensor) | public void Add(Tensor tensor) | ||||
| => items.Add(tensor); | |||||
| { | |||||
| if(NestType == NestType.Dictionary) | |||||
| { | |||||
| throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||||
| } | |||||
| else if(NestType == NestType.Node) | |||||
| { | |||||
| NestType = NestType.List; | |||||
| ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) }; | |||||
| Value = null; | |||||
| } | |||||
| else | |||||
| { | |||||
| ListValue.Add(new Nest<Tensor>(tensor)); | |||||
| } | |||||
| } | |||||
| [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + | |||||
| "some tensors to `Tensors`, creating a new instance with your newly added tensors is a better choice.")] | |||||
| public void AddRange(IEnumerable<Tensor> tensors) | public void AddRange(IEnumerable<Tensor> tensors) | ||||
| => items.AddRange(tensors); | |||||
| { | |||||
| if (NestType == NestType.Dictionary) | |||||
| { | |||||
| throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||||
| } | |||||
| else if (NestType == NestType.Node) | |||||
| { | |||||
| NestType = NestType.List; | |||||
| ListValue = new() { new Nest<Tensor>(Value) }; | |||||
| ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||||
| Value = null; | |||||
| } | |||||
| else | |||||
| { | |||||
| ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||||
| } | |||||
| } | |||||
| [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to insert " + | |||||
| "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] | |||||
| public void Insert(int index, Tensor tensor) | public void Insert(int index, Tensor tensor) | ||||
| => items.Insert(index, tensor); | |||||
| IEnumerator IEnumerable.GetEnumerator() | |||||
| => GetEnumerator(); | |||||
| { | |||||
| if (NestType == NestType.List) | |||||
| { | |||||
| ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||||
| } | |||||
| else if(NestType == NestType.Node) | |||||
| { | |||||
| NestType = NestType.List; | |||||
| ListValue = new() { new Nest<Tensor>(Value) }; | |||||
| ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||||
| Value = null; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||||
| } | |||||
| } | |||||
| public string[] StringData() | public string[] StringData() | ||||
| { | { | ||||
| EnsureSingleTensor(this, "nnumpy"); | |||||
| return this[0].StringData(); | |||||
| return Single.StringData(); | |||||
| } | } | ||||
| public string StringData(int index) | public string StringData(int index) | ||||
| { | { | ||||
| EnsureSingleTensor(this, "nnumpy"); | |||||
| return this[0].StringData(index); | |||||
| return Single.StringData(index); | |||||
| } | } | ||||
| public NDArray numpy() | public NDArray numpy() | ||||
| { | { | ||||
| EnsureSingleTensor(this, "nnumpy"); | |||||
| return this[0].numpy(); | |||||
| return Single.numpy(); | |||||
| } | } | ||||
| [Obsolete] | |||||
| public T[] ToArray<T>() where T: unmanaged | public T[] ToArray<T>() where T: unmanaged | ||||
| { | { | ||||
| EnsureSingleTensor(this, $"ToArray<{typeof(T)}>"); | |||||
| return this[0].ToArray<T>(); | |||||
| return Single.ToArray<T>(); | |||||
| } | } | ||||
| #region Explicit Conversions | #region Explicit Conversions | ||||
| public unsafe static explicit operator bool(Tensors tensor) | public unsafe static explicit operator bool(Tensors tensor) | ||||
| { | { | ||||
| EnsureSingleTensor(tensor, "explicit conversion to bool"); | |||||
| return (bool)tensor[0]; | |||||
| return (bool)tensor.Single; | |||||
| } | } | ||||
| public unsafe static explicit operator sbyte(Tensors tensor) | public unsafe static explicit operator sbyte(Tensors tensor) | ||||
| { | { | ||||
| EnsureSingleTensor(tensor, "explicit conversion to sbyte"); | |||||
| return (sbyte)tensor[0]; | |||||
| return (sbyte)tensor.Single; | |||||
| } | } | ||||
| public unsafe static explicit operator byte(Tensors tensor) | public unsafe static explicit operator byte(Tensors tensor) | ||||
| { | { | ||||
| EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||||
| return (byte)tensor[0]; | |||||
| return (byte)tensor.Single; | |||||
| } | } | ||||
| public unsafe static explicit operator ushort(Tensors tensor) | public unsafe static explicit operator ushort(Tensors tensor) | ||||
| { | { | ||||
| EnsureSingleTensor(tensor, "explicit conversion to ushort"); | |||||
| return (ushort)tensor[0]; | |||||
| return (ushort)tensor.Single; | |||||
| } | } | ||||
| public unsafe static explicit operator short(Tensors tensor) | public unsafe static explicit operator short(Tensors tensor) | ||||
| { | { | ||||
| EnsureSingleTensor(tensor, "explicit conversion to short"); | |||||
| return (short)tensor[0]; | |||||
| return (short)tensor.Single; | |||||
| } | } | ||||
| public unsafe static explicit operator int(Tensors tensor) | public unsafe static explicit operator int(Tensors tensor) | ||||
| { | { | ||||
| EnsureSingleTensor(tensor, "explicit conversion to int"); | |||||
| return (int)tensor[0]; | |||||
| return (int)tensor.Single; | |||||
| } | } | ||||
| public unsafe static explicit operator uint(Tensors tensor) | public unsafe static explicit operator uint(Tensors tensor) | ||||
| { | { | ||||
| EnsureSingleTensor(tensor, "explicit conversion to uint"); | |||||
| return (uint)tensor[0]; | |||||
| return (uint)tensor.Single; | |||||
| } | } | ||||
| public unsafe static explicit operator long(Tensors tensor) | public unsafe static explicit operator long(Tensors tensor) | ||||
| { | { | ||||
| EnsureSingleTensor(tensor, "explicit conversion to long"); | |||||
| return (long)tensor[0]; | |||||
| return (long)tensor.Single; | |||||
| } | } | ||||
| public unsafe static explicit operator ulong(Tensors tensor) | public unsafe static explicit operator ulong(Tensors tensor) | ||||
| { | { | ||||
| EnsureSingleTensor(tensor, "explicit conversion to ulong"); | |||||
| return (ulong)tensor[0]; | |||||
| return (ulong)tensor.Single; | |||||
| } | } | ||||
| public unsafe static explicit operator float(Tensors tensor) | public unsafe static explicit operator float(Tensors tensor) | ||||
| { | { | ||||
| EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||||
| return (byte)tensor[0]; | |||||
| return (byte)tensor.Single; | |||||
| } | } | ||||
| public unsafe static explicit operator double(Tensors tensor) | public unsafe static explicit operator double(Tensors tensor) | ||||
| { | { | ||||
| EnsureSingleTensor(tensor, "explicit conversion to double"); | |||||
| return (double)tensor[0]; | |||||
| return (double)tensor.Single; | |||||
| } | } | ||||
| public unsafe static explicit operator string(Tensors tensor) | public unsafe static explicit operator string(Tensors tensor) | ||||
| { | { | ||||
| EnsureSingleTensor(tensor, "explicit conversion to string"); | |||||
| return (string)tensor[0]; | |||||
| return (string)tensor.Single; | |||||
| } | } | ||||
| public static explicit operator object[](Tensors tensors) | public static explicit operator object[](Tensors tensors) | ||||
| => tensors.items.ToArray(); | |||||
| => tensors.Flatten().ToArray(); | |||||
| #endregion | #endregion | ||||
| #region Implicit Conversions | #region Implicit Conversions | ||||
| @@ -219,52 +262,40 @@ namespace Tensorflow | |||||
| => tensors?.SingleOrNull; | => tensors?.SingleOrNull; | ||||
| public static implicit operator Tensor[](Tensors tensors) | public static implicit operator Tensor[](Tensors tensors) | ||||
| => tensors.items.ToArray(); | |||||
| => tensors.Flatten().ToArray(); | |||||
| #endregion | #endregion | ||||
| public void Deconstruct(out Tensor a, out Tensors? b) | |||||
| public static Tensors? FromNest(Nest<Tensor> nested) | |||||
| { | { | ||||
| a = items[0]; | |||||
| b = Length == 1? null : new Tensors(items.Skip(1)); | |||||
| if(nested == Nest<Tensor>.Empty) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| return new Tensors(nested); | |||||
| } | } | ||||
| private static void EnsureSingleTensor(Tensors tensors, string methodnName) | |||||
| public void Deconstruct(out Tensor a, out Tensors? b) | |||||
| { | { | ||||
| if(tensors.Length == 0) | |||||
| { | |||||
| throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor."); | |||||
| } | |||||
| else if(tensors.Length > 1) | |||||
| { | |||||
| throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor."); | |||||
| } | |||||
| a = this.First(); | |||||
| b = Length == 1? null : new Tensors(this.Skip(1)); | |||||
| } | } | ||||
| public override string ToString() | public override string ToString() | ||||
| { | { | ||||
| if(items.Count == 1) | |||||
| if(Length == 1) | |||||
| { | { | ||||
| return items[0].ToString(); | |||||
| return this.First().ToString(); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| StringBuilder sb = new StringBuilder(); | |||||
| sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n"); | |||||
| for(int i = 0; i < items.Count; i++) | |||||
| { | |||||
| var tensor = items[i]; | |||||
| sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n"); | |||||
| } | |||||
| sb.Append("]\n"); | |||||
| return sb.ToString(); | |||||
| return $"Totally {Length} tensors: {base.ToString()}"; | |||||
| } | } | ||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| foreach (var item in items) | |||||
| item.Dispose(); | |||||
| foreach (var tensor in this) | |||||
| tensor.Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -36,6 +36,7 @@ namespace Tensorflow.Util | |||||
| // (np.array([3, 4]), tf.constant([3, 4])))` | // (np.array([3, 4]), tf.constant([3, 4])))` | ||||
| // | // | ||||
| [Obsolete] | |||||
| public static class nest | public static class nest | ||||
| { | { | ||||
| @@ -170,39 +171,6 @@ namespace Tensorflow.Util | |||||
| throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | ||||
| } | } | ||||
| public static bool is_nested(object obj) | |||||
| { | |||||
| // Refer to https://www.tensorflow.org/api_docs/python/tf/nest | |||||
| //if (obj is IList || obj is IDictionary || obj is ITuple) | |||||
| // return true; | |||||
| if (obj is IList || obj is IDictionary) | |||||
| return true; | |||||
| if (obj is NDArray || obj is Tensor || obj is string || obj.GetType().IsGenericType | |||||
| || obj is ISet<int> || obj is ISet<float> || obj is ISet<double>) | |||||
| return false; | |||||
| if (obj.GetType().IsNested) return true; | |||||
| // Check if the object is an IEnumerable | |||||
| if (obj is IEnumerable) | |||||
| { | |||||
| // If it is, check if it is a nested structure | |||||
| foreach (object item in (IEnumerable)obj) | |||||
| { | |||||
| if (is_nested(item)) | |||||
| { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| else | |||||
| { | |||||
| // If it is not, return false | |||||
| return false; | |||||
| } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Yields the next value from the given iterable. | /// Yields the next value from the given iterable. | ||||
| /// </summary> | /// </summary> | ||||