feat: Support training of RNN and LSTM (tag: v0.110.0-LSTM-Model)
| @@ -16,6 +16,7 @@ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -50,6 +51,19 @@ namespace Tensorflow | |||
| return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | |||
| } | |||
| public unsafe static byte[] ByteStringPiece(IntPtr handle) | |||
| { | |||
| byte* str_data = (byte*)handle.ToPointer(); | |||
| List<byte> bytes = new List<byte>(); | |||
| byte current = 255; | |||
| while (current != ((byte)'\0')) | |||
| { | |||
| current = *(str_data++); | |||
| bytes.Add(current); | |||
| } | |||
| return bytes.Take(bytes.Count - 1).ToArray(); | |||
| } | |||
| [UnmanagedFunctionPointer(CallingConvention.Winapi)] | |||
| public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args); | |||
| @@ -46,10 +46,10 @@ namespace Tensorflow | |||
| Tensor loop_vars, | |||
| int parallel_iterations = 10) | |||
| { | |||
| Func<Tensor[], Tensor> cond1 = x | |||
| Func<Tensors, Tensor> cond1 = x | |||
| => cond(x[0]); | |||
| Func<Tensor[], Tensor[]> body1 = x | |||
| Func<Tensors, Tensors> body1 = x | |||
| => new[] { body(x[0]) }; | |||
| var results = control_flow_ops.while_loop(cond1, | |||
| @@ -58,9 +58,9 @@ namespace Tensorflow | |||
| return results[0]; | |||
| } | |||
| public Tensor[] while_loop(Func<Tensor[], Tensor> cond, | |||
| Func<Tensor[], Tensor[]> body, | |||
| Tensor[] loop_vars, | |||
| public Tensor[] while_loop(Func<Tensors, Tensor> cond, | |||
| Func<Tensors, Tensors> body, | |||
| Tensors loop_vars, | |||
| int parallel_iterations = 10, | |||
| string name = null) | |||
| => control_flow_ops.while_loop(cond, body, loop_vars, | |||
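Editor's note: a minimal usage sketch of the `Tensors`-based `while_loop` overload above. It assumes the usual `tf.constant`, `tf.less` and `tf.add` bindings and a `Tensors` constructor that accepts tensors; none of those names are part of this hunk.

    // Sketch only: count from 0 to 10 with the Tensors-based overload.
    Func<Tensors, Tensor> cond = x => tf.less(x[0], tf.constant(10));
    Func<Tensors, Tensors> body = x => new Tensors(tf.add(x[0], tf.constant(1)));
    var results = tf.while_loop(cond, body, new Tensors(tf.constant(0)));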
| @@ -71,15 +71,15 @@ namespace Tensorflow | |||
| public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) | |||
| => array_ops.split( | |||
| value: value, | |||
| num_split: num_split, | |||
| num_or_size_splits: num_split, | |||
| axis: axis, | |||
| name: name); | |||
| public Tensor[] split(Tensor value, int num_split, int axis, string name = null) | |||
| => array_ops.split( | |||
| value: value, | |||
| num_split: num_split, | |||
| axis: axis, | |||
| num_or_size_splits: num_split, | |||
| axis: ops.convert_to_tensor(axis), | |||
| name: name); | |||
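Editor's note: a rough example of the renamed `num_or_size_splits` parameter through the public `tf.split` wrapper above; `tf.ones` and the chosen shape are assumptions, not part of this hunk.

    // Sketch only: split a [6, 4] tensor into 3 equal [2, 4] pieces along axis 0.
    var value = tf.ones(new Shape(6, 4));
    Tensor[] parts = tf.split(value, num_split: 3, axis: 0);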
| public Tensor ensure_shape(Tensor x, Shape shape, string name = null) | |||
| @@ -524,7 +524,7 @@ namespace Tensorflow | |||
| case Tensors tensors: | |||
| return tensors.dtype; | |||
| case IEnumerable<Tensor> tensors: | |||
| return tensors.First().dtype; | |||
| return tensors.Where(x => x is not null).First().dtype; | |||
| case RefVariable variable: | |||
| return variable.dtype; | |||
| case ResourceVariable variable: | |||
| @@ -3,16 +3,16 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Extensions | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| public static class JObjectExtensions | |||
| { | |||
| public static T? TryGetOrReturnNull<T>(this JObject obj, string key) | |||
| { | |||
| var res = obj[key]; | |||
| if(res is null) | |||
| if (res is null) | |||
| { | |||
| return default(T); | |||
| return default; | |||
| } | |||
| else | |||
| { | |||
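Editor's note: a small sketch of `TryGetOrReturnNull`, assuming Newtonsoft.Json is referenced and that the omitted else branch converts the token (e.g. via `ToObject<T>()`).

    // Sketch only: optional keys in a Keras config object.
    var config = JObject.Parse("{\"units\": 16}");
    var units = config.TryGetOrReturnNull<int?>("units");   // 16
    var name  = config.TryGetOrReturnNull<string>("name");  // null, key is absent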
| @@ -0,0 +1,38 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| public static class LinqExtensions | |||
| { | |||
| #if NETSTANDARD2_0 | |||
| public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count) | |||
| { | |||
| return sequence.Skip(sequence.Count() - count); | |||
| } | |||
| public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count) | |||
| { | |||
| return sequence.Take(sequence.Count() - count); | |||
| } | |||
| #endif | |||
| public static Tensors ToTensors(this Tensor[] tensors) | |||
| { | |||
| return new Tensors(tensors); | |||
| } | |||
| public static Tensors ToTensors(this IList<Tensor> tensors) | |||
| { | |||
| return new Tensors(tensors); | |||
| } | |||
| public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) | |||
| { | |||
| first = values.Item1; | |||
| second = values.Item2; | |||
| third = values.Item3; | |||
| } | |||
| } | |||
| } | |||
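Editor's note: a quick sketch of the helpers above. `TakeLast`/`SkipLast` are only compiled as polyfills on netstandard2.0 (newer frameworks already provide them), and `ToTensors` simply wraps the sequence in a `Tensors`.

    // Sketch only.
    var seq = new List<int> { 1, 2, 3, 4, 5 };
    var lastTwo = seq.TakeLast(2);                                    // { 4, 5 }
    var packed = new[] { tf.constant(1f), tf.constant(2f) }.ToTensors();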
| @@ -0,0 +1,33 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| public static class NestExtensions | |||
| { | |||
| public static Tensors ToTensors(this INestable<Tensor> tensors) | |||
| { | |||
| return new Tensors(tensors.AsNest()); | |||
| } | |||
| public static Tensors? ToTensors(this Nest<Tensor> tensors) | |||
| { | |||
| return Tensors.FromNest(tensors); | |||
| } | |||
| /// <summary> | |||
| /// If the leaves of the nested object are themselves nested structures, this function reduces it. | |||
| /// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. | |||
| /// </summary> | |||
| /// <typeparam name="TIn"></typeparam> | |||
| /// <typeparam name="TOut"></typeparam> | |||
| /// <param name="input"></param> | |||
| /// <returns></returns> | |||
| public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut> | |||
| { | |||
| return Nest<TOut>.ReduceFrom(input); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,20 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// This is a temporary solution, which should be removed after `Tensors` is refactored. | |||
| /// </summary> | |||
| [Obsolete] | |||
| public class FakeTensorByTensorArray: Tensor | |||
| { | |||
| public TensorArray TensorArray { get; set; } | |||
| public FakeTensorByTensorArray(TensorArray array) | |||
| { | |||
| TensorArray = array; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,69 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public class GeneralizedTensorShape: Nest<Shape> | |||
| { | |||
| public GeneralizedTensorShape(Shape value, string? name = null) | |||
| { | |||
| NodeValue = value; | |||
| NestType = NestType.Node; | |||
| } | |||
| public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null) | |||
| { | |||
| ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList(); | |||
| Name = name; | |||
| NestType = NestType.List; | |||
| } | |||
| public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null) | |||
| { | |||
| DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>); | |||
| Name = name; | |||
| NestType = NestType.Dictionary; | |||
| } | |||
| public GeneralizedTensorShape(Nest<Shape> other) | |||
| { | |||
| NestType = other.NestType; | |||
| NodeValue = other.NodeValue; | |||
| DictValue = other.DictValue; | |||
| ListValue = other.ListValue; | |||
| Name = other.Name; | |||
| } | |||
| public Shape ToSingleShape() | |||
| { | |||
| var shapes = Flatten().ToList(); | |||
| if (shapes.Count != 1) | |||
| { | |||
| throw new ValueError("The generalized shape contains more than 1 dim."); | |||
| } | |||
| return shapes[0]; | |||
| } | |||
| public long ToNumber() | |||
| { | |||
| var shapes = Flatten().ToList(); | |||
| if (shapes.Count != 1 || shapes[0].ndim != 1) | |||
| { | |||
| throw new ValueError("The generalized shape is not a single rank-1 shape."); | |||
| } | |||
| return shapes[0].dims[0]; | |||
| } | |||
| public INestStructure<TensorShapeConfig> ToTensorShapeConfigs() | |||
| { | |||
| return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() }); | |||
| } | |||
| public static implicit operator GeneralizedTensorShape(Shape shape) | |||
| { | |||
| return new GeneralizedTensorShape(shape); | |||
| } | |||
| } | |||
| } | |||
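Editor's note: a sketch of how the generalized shape is meant to be used; the shape values are arbitrary.

    // Sketch only: a single shape via the implicit operator, and a list of per-state shapes.
    GeneralizedTensorShape single = new Shape(32, 10);
    var states = new GeneralizedTensorShape(new[] { new Shape(32, 64), new Shape(32, 64) });
    var shape = single.ToSingleShape();     // (32, 10)
    var count = states.Flatten().Count();   // 2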
| @@ -0,0 +1,40 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// This interface indicates that a class may have a nested structure and provide | |||
| /// methods to manipulate with the structure. | |||
| /// </summary> | |||
| public interface INestStructure<T>: INestable<T> | |||
| { | |||
| NestType NestType { get; } | |||
| /// <summary> | |||
| /// The item count of depth 1 of the nested structure. | |||
| /// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3. | |||
| /// </summary> | |||
| int ShallowNestedCount { get; } | |||
| /// <summary> | |||
| /// The total item count of depth 1 of the nested structure. | |||
| /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. | |||
| /// </summary> | |||
| int TotalNestedCount { get; } | |||
| /// <summary> | |||
| /// Flatten the Nestable object. Note that if the object contains only one value, | |||
| /// it will be flattened to an enumerable with one element. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| IEnumerable<T> Flatten(); | |||
| /// <summary> | |||
| /// Construct a new object with the same nested structure. | |||
| /// </summary> | |||
| /// <typeparam name="TOut"></typeparam> | |||
| /// <param name="func"></param> | |||
| /// <returns></returns> | |||
| INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func); | |||
| } | |||
| } | |||
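Editor's note: to illustrate the two counts, a sketch using the `Nest<T>` implementation added later in this commit.

    // Sketch only: the structure [1, 2, [3, 4, 5]].
    var nested = new Nest<int>(new INestStructure<int>[]
    {
        new Nest<int>(1),
        new Nest<int>(2),
        new Nest<int>(new INestStructure<int>[] { new Nest<int>(3), new Nest<int>(4), new Nest<int>(5) })
    });
    // nested.ShallowNestedCount == 3, nested.TotalNestedCount == 5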
| @@ -0,0 +1,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public interface INestable<T> | |||
| { | |||
| Nest<T> AsNest(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,21 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// This interface is used when the corresponding Python methods have optional args. | |||
| /// For example, `Keras.Layer.Apply` generally takes three args as inputs, while | |||
| /// `Keras.Layer.RNN` takes more. When calling RNN, pass an `RnnOptionalArgs` | |||
| /// instance as the optional-args parameter of the method. | |||
| /// </summary> | |||
| public interface IOptionalArgs | |||
| { | |||
| /// <summary> | |||
| /// The identifier of the class. It is not an argument itself, but is used to | |||
| /// distinguish between different kinds of optional args. | |||
| /// </summary> | |||
| string Identifier { get; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,62 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public static class Nest | |||
| { | |||
| /// <summary> | |||
| /// Pack the flat items to a nested sequence by the template. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="template"></param> | |||
| /// <param name="flatItems"></param> | |||
| /// <returns></returns> | |||
| public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems) | |||
| { | |||
| return template.AsNest().PackSequence(flatItems); | |||
| } | |||
| /// <summary> | |||
| /// Pack the flat items to a nested sequence by the template. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="template"></param> | |||
| /// <param name="flatItems"></param> | |||
| /// <returns></returns> | |||
| public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems) | |||
| { | |||
| return template.AsNest().PackSequence(flatItems.ToArray()); | |||
| } | |||
| /// <summary> | |||
| /// Flatten the nested object. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="nestedObject"></param> | |||
| /// <returns></returns> | |||
| public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject) | |||
| { | |||
| return nestedObject.AsNest().Flatten(); | |||
| } | |||
| /// <summary> | |||
| /// Map the structure with specified function. | |||
| /// </summary> | |||
| /// <typeparam name="TIn"></typeparam> | |||
| /// <typeparam name="TOut"></typeparam> | |||
| /// <param name="func"></param> | |||
| /// <param name="nestedObject"></param> | |||
| /// <returns></returns> | |||
| public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject) | |||
| { | |||
| return nestedObject.AsNest().MapStructure(func); | |||
| } | |||
| public static bool IsNested<T>(INestable<T> obj) | |||
| { | |||
| return obj.AsNest().IsNested(); | |||
| } | |||
| } | |||
| } | |||
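Editor's note: a small sketch of flatten/repack round-tripping with the static helpers above.

    // Sketch only: flatten a template, then pack new leaves back into the same structure.
    var template = new Nest<int>(new INestStructure<int>[] { new Nest<int>(1), new Nest<int>(2) });
    var flat = Nest.Flatten(template).ToArray();                      // [1, 2]
    var repacked = Nest.PackSequenceAs(template, new[] { 10, 20 });   // same structure, new leaves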
| @@ -0,0 +1,485 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Extensions; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public enum NestType | |||
| { | |||
| Empty, | |||
| Node, | |||
| List, | |||
| Dictionary | |||
| } | |||
| /// <summary> | |||
| /// A nested structure which may include a value, a list, or a dictionary. | |||
| /// Note that dictionary does not ensure the data order. When using it as IEnumerable, | |||
| /// its order is depth-first. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| public class Nest<T> : INestStructure<T>, IEnumerable<T> | |||
| { | |||
| private static readonly Nest<T> _empty = new Nest<T>() | |||
| { | |||
| NestType = NestType.Empty, | |||
| }; | |||
| public static Nest<T> Empty => _empty; | |||
| public NestType NestType { get; protected set; } | |||
| public string? Name { get; set; } | |||
| public T? NodeValue { get; protected set; } | |||
| public List<INestStructure<T>>? ListValue { get; protected set; } | |||
| public Dictionary<string, INestStructure<T>>? DictValue { get; protected set; } | |||
| public int ShallowNestedCount | |||
| { | |||
| get | |||
| { | |||
| if (NestType == NestType.Empty) | |||
| { | |||
| return 0; | |||
| } | |||
| else if (NestType == NestType.Node) | |||
| { | |||
| return 1; | |||
| } | |||
| else if (NestType == NestType.List) | |||
| { | |||
| return ListValue!.Count; | |||
| } | |||
| else // dict | |||
| { | |||
| return DictValue!.Count; | |||
| } | |||
| } | |||
| } | |||
| public int TotalNestedCount | |||
| { | |||
| get | |||
| { | |||
| return Flatten().Count(); | |||
| } | |||
| } | |||
| protected Nest() { } | |||
| public Nest(T value, string? name = null) | |||
| { | |||
| NodeValue = value; | |||
| Name = name; | |||
| NestType = NestType.Node; | |||
| } | |||
| public Nest(IEnumerable<INestStructure<T>> values, string? name = null) | |||
| { | |||
| ListValue = values.ToList(); | |||
| Name = name; | |||
| NestType = NestType.List; | |||
| } | |||
| public Nest(Dictionary<string, INestStructure<T>> value, string? name = null) | |||
| { | |||
| DictValue = value; | |||
| Name = name; | |||
| NestType = NestType.Dictionary; | |||
| } | |||
| public Nest(Nest<T> other) | |||
| { | |||
| NestType = other.NestType; | |||
| NodeValue = other.NodeValue; | |||
| DictValue = other.DictValue; | |||
| ListValue = other.ListValue; | |||
| Name = other.Name; | |||
| } | |||
| public virtual IEnumerable<T> Flatten() | |||
| { | |||
| return FlattenInternal(this); | |||
| } | |||
| public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
| { | |||
| return MapStructureInternal(func); | |||
| } | |||
| /// <summary> | |||
| /// Pack the flat items to a nested sequence by the template. | |||
| /// </summary> | |||
| /// <param name="flatItems"></param> | |||
| /// <returns></returns> | |||
| public virtual Nest<TOut> PackSequence<TOut>(TOut[] flatItems) | |||
| { | |||
| if(flatItems.Length == 0) | |||
| { | |||
| return Nest<TOut>.Empty; | |||
| } | |||
| int index = 0; | |||
| return PackSequenceInternal(this, flatItems, ref index); | |||
| } | |||
| private static Nest<TOut> PackSequenceInternal<TOut>(Nest<T> template, TOut[] flatItems, ref int index) | |||
| { | |||
| if(template.NestType == NestType.Node) | |||
| { | |||
| if(index >= flatItems.Length) | |||
| { | |||
| throw new InvalidArgumentError("The template and flat items are not matched."); | |||
| } | |||
| return new Nest<TOut>(flatItems[index++]); | |||
| } | |||
| else if(template.NestType == NestType.List) | |||
| { | |||
| List<Nest<TOut>> nestedObjects = new List<Nest<TOut>>(); | |||
| for (int i = 0; i < template.ListValue!.Count; i++) | |||
| { | |||
| nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index)); | |||
| } | |||
| return new Nest<TOut>(nestedObjects); | |||
| } | |||
| else if(template.NestType == NestType.Dictionary) | |||
| { | |||
| Dictionary<string, INestStructure<TOut>> dict = new Dictionary<string, INestStructure<TOut>>(); | |||
| foreach(var (key, value) in template.DictValue!) | |||
| { | |||
| dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index); | |||
| } | |||
| return new Nest<TOut>(dict); | |||
| } | |||
| // Consider Empty as invalid type. | |||
| throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); | |||
| } | |||
| public virtual Nest<T> AsNest() | |||
| { | |||
| return this; | |||
| } | |||
| public virtual Nest<T> MergeWith(Nest<T>? other) | |||
| { | |||
| if(other is null || other == Nest<T>.Empty) | |||
| { | |||
| return this; | |||
| } | |||
| if(this == Nest<T>.Empty) | |||
| { | |||
| return other; | |||
| } | |||
| if(NestType == NestType.Node && other.NestType == NestType.Node) | |||
| { | |||
| return new Nest<T>(new Nest<T>[] { this, other }); | |||
| } | |||
| else if(NestType == NestType.List && other.NestType == NestType.List) | |||
| { | |||
| return new Nest<T>(this.ListValue!.Concat(other.ListValue!)); | |||
| } | |||
| else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary) | |||
| { | |||
| return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value)); | |||
| } | |||
| else | |||
| { | |||
| return new Nest<T>(new Nest<T>[] { this, other }); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not | |||
| /// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public bool IsNested() | |||
| { | |||
| if(NestType is NestType.Empty or NestType.Node) | |||
| { | |||
| return false; | |||
| } | |||
| else if(NestType is NestType.List) | |||
| { | |||
| return ListValue!.Count > 0; | |||
| } | |||
| else | |||
| { | |||
| return DictValue!.Count > 0; | |||
| } | |||
| } | |||
| [Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] | |||
| public T this[int index] | |||
| { | |||
| get | |||
| { | |||
| bool success = FindInternal(this, index, out var result); | |||
| if (success) | |||
| { | |||
| return result; | |||
| } | |||
| else | |||
| { | |||
| throw new IndexOutOfRangeException(); | |||
| } | |||
| } | |||
| set | |||
| { | |||
| bool success = SetInternal(this, index, value); | |||
| if (!success) | |||
| { | |||
| throw new IndexOutOfRangeException(); | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// If the existing nested structure is of type `Nest[INestStructure[T]]`, we can reduce it | |||
| /// to `Nest[T]`. | |||
| /// </summary> | |||
| /// <typeparam name="TOut"></typeparam> | |||
| /// <param name="input"></param> | |||
| /// <returns></returns> | |||
| public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T> | |||
| { | |||
| var nested = input.AsNest(); | |||
| return ReduceInternal(nested).AsNest(); | |||
| } | |||
| private static INestStructure<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T> | |||
| { | |||
| if(node.NestType == NestType.Empty) | |||
| { | |||
| return Nest<T>.Empty; | |||
| } | |||
| else if(node.NestType == NestType.Node) | |||
| { | |||
| return node.NodeValue!.AsNest(); | |||
| } | |||
| else if(node.NestType == NestType.List) | |||
| { | |||
| return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x.AsNest()))); | |||
| } | |||
| else // Dictionary type | |||
| { | |||
| return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest()))); | |||
| } | |||
| } | |||
| private static bool FindInternal(Nest<T> node, int index, out T? result) | |||
| { | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| if(index == 0) | |||
| { | |||
| result = node.NodeValue!; | |||
| return true; | |||
| } | |||
| result = default(T); | |||
| return false; | |||
| } | |||
| else if (node.NestType == NestType.List) | |||
| { | |||
| foreach (var item in node.ListValue!) | |||
| { | |||
| if(index == 0) | |||
| { | |||
| return FindInternal(item.AsNest(), index, out result); | |||
| } | |||
| index--; | |||
| } | |||
| result = default(T); | |||
| return false; | |||
| } | |||
| else if(node.NestType == NestType.Dictionary) | |||
| { | |||
| foreach (var item in node.DictValue!.Values) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| return FindInternal(item.AsNest(), index, out result); | |||
| } | |||
| index--; | |||
| } | |||
| result = default(T); | |||
| return false; | |||
| } | |||
| else | |||
| { | |||
| result = default(T); | |||
| return false; | |||
| } | |||
| } | |||
| private static bool SetInternal(Nest<T> node, int index, T newValue) | |||
| { | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| node.NodeValue = newValue; | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| else if (node.NestType == NestType.List) | |||
| { | |||
| foreach (var item in node.ListValue!) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| return SetInternal(item.AsNest(), index, newValue); | |||
| } | |||
| index--; | |||
| } | |||
| return false; | |||
| } | |||
| else if (node.NestType == NestType.Dictionary) | |||
| { | |||
| foreach (var item in node.DictValue!.Values) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| return SetInternal(item.AsNest(), index, newValue); | |||
| } | |||
| index--; | |||
| } | |||
| return false; | |||
| } | |||
| else | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| private static IEnumerable<T> FlattenInternal(Nest<T> node) | |||
| { | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| yield return node.NodeValue!; | |||
| } | |||
| else if (node.NestType == NestType.List) | |||
| { | |||
| foreach (var item in node.ListValue!) | |||
| { | |||
| foreach(var val in FlattenInternal(item.AsNest())) | |||
| { | |||
| yield return val; | |||
| } | |||
| } | |||
| } | |||
| else if (node.NestType == NestType.Dictionary) | |||
| { | |||
| foreach (var item in node.DictValue!.Values) | |||
| { | |||
| foreach (var val in FlattenInternal(item.AsNest())) | |||
| { | |||
| yield return val; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func) | |||
| { | |||
| if (NestType == NestType.Node) | |||
| { | |||
| return new Nest<TOut>(func(NodeValue!)); | |||
| } | |||
| else if (NestType == NestType.List) | |||
| { | |||
| List<Nest<TOut>> outs = new List<Nest<TOut>>(); | |||
| foreach (var item in ListValue!) | |||
| { | |||
| outs.Add(item.AsNest().MapStructureInternal(func)); | |||
| } | |||
| return new Nest<TOut>(outs); | |||
| } | |||
| else if (NestType == NestType.Dictionary) | |||
| { | |||
| Dictionary<string, INestStructure<TOut>> outs = new Dictionary<string, INestStructure<TOut>>(); | |||
| foreach (var (key, value) in DictValue!) | |||
| { | |||
| outs.Add(key, value.AsNest().MapStructureInternal(func)); | |||
| } | |||
| return new Nest<TOut>(outs); | |||
| } | |||
| else | |||
| { | |||
| return Nest<TOut>.Empty; | |||
| } | |||
| } | |||
| public IEnumerator<T> GetEnumerator() | |||
| { | |||
| return Flatten().GetEnumerator(); | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| return GetEnumerator(); | |||
| } | |||
| public override string ToString() | |||
| { | |||
| StringBuilder sb = new StringBuilder(); | |||
| sb.Append("("); | |||
| WriteString(this, sb); | |||
| sb.Append(")"); | |||
| return sb.ToString(); | |||
| } | |||
| private static void WriteString(Nest<T> node, StringBuilder sb) | |||
| { | |||
| if (!string.IsNullOrEmpty(node.Name)) | |||
| { | |||
| sb.Append($"{node.Name}: "); | |||
| } | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| sb.Append(node.NodeValue!.ToString()); | |||
| } | |||
| else if (node.NestType == NestType.List) | |||
| { | |||
| sb.Append("["); | |||
| for(int i = 0; i < node.ListValue!.Count; i++) | |||
| { | |||
| WriteString(node.ListValue![i].AsNest(), sb); | |||
| if(i != node.ListValue!.Count - 1) | |||
| { | |||
| sb.Append(", "); | |||
| } | |||
| } | |||
| sb.Append("]"); | |||
| } | |||
| else if (node.NestType == NestType.Dictionary) | |||
| { | |||
| sb.Append("{"); | |||
| int count = node.DictValue!.Count; | |||
| int i = 0; | |||
| foreach (var (key, value) in node.DictValue!) | |||
| { | |||
| sb.Append($"{key}: "); | |||
| WriteString(value.AsNest(), sb); | |||
| if (i != count - 1) | |||
| { | |||
| sb.Append(", "); | |||
| } | |||
| i++; | |||
| } | |||
| sb.Append("}"); | |||
| } | |||
| else | |||
| { | |||
| sb.Append("<empty>"); | |||
| } | |||
| } | |||
| public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>) inputs) | |||
| { | |||
| return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2 }); | |||
| } | |||
| public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>, INestStructure<T>) inputs) | |||
| { | |||
| return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2, inputs.Item3 }); | |||
| } | |||
| } | |||
| } | |||
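Editor's note: a sketch of the merge and rendering behavior defined above.

    // Sketch only: merging two node nests yields a list nest.
    var a = new Nest<string>("x");
    var b = new Nest<string>("y");
    var merged = a.MergeWith(b);
    Console.WriteLine(merged);             // prints "([x, y])"
    Console.WriteLine(merged.IsNested());  // True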
| @@ -0,0 +1,103 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull | |||
| { | |||
| public NestType NestType => NestType.Dictionary; | |||
| public IDictionary<TKey, TValue> Value { get; set; } | |||
| public int ShallowNestedCount => Values.Count; | |||
| public int TotalNestedCount => Values.Count; | |||
| public NestDictionary(IDictionary<TKey, TValue> dict) | |||
| { | |||
| Value = dict; | |||
| } | |||
| public IEnumerable<TValue> Flatten() | |||
| { | |||
| return Value.Select(x => x.Value); | |||
| } | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func) | |||
| { | |||
| return new NestList<TOut>(Value.Select(x => func(x.Value))); | |||
| } | |||
| public Nest<TValue> AsNest() | |||
| { | |||
| return new Nest<TValue>(Value.Values.Select(x => new Nest<TValue>(x))); | |||
| } | |||
| // Required IDictionary<TKey, TValue> members | |||
| public int Count => Value.Count; | |||
| public bool IsReadOnly => Value.IsReadOnly; | |||
| public ICollection<TKey> Keys => Value.Keys; | |||
| public ICollection<TValue> Values => Value.Values; | |||
| public void Add(TKey key, TValue value) | |||
| { | |||
| Value.Add(key, value); | |||
| } | |||
| public void Add(KeyValuePair<TKey, TValue> item) | |||
| { | |||
| Value.Add(item); | |||
| } | |||
| public void Clear() | |||
| { | |||
| Value.Clear(); | |||
| } | |||
| public bool Contains(KeyValuePair<TKey, TValue> item) | |||
| { | |||
| return Value.Contains(item); | |||
| } | |||
| public bool ContainsKey(TKey key) | |||
| { | |||
| return Value.ContainsKey(key); | |||
| } | |||
| public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||
| { | |||
| Value.CopyTo(array, arrayIndex); | |||
| } | |||
| public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||
| { | |||
| return Value.GetEnumerator(); | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| return GetEnumerator(); | |||
| } | |||
| public bool Remove(TKey key) | |||
| { | |||
| return Value.Remove(key); | |||
| } | |||
| public bool Remove(KeyValuePair<TKey, TValue> item) | |||
| { | |||
| return Value.Remove(item); | |||
| } | |||
| public bool TryGetValue(TKey key, out TValue value) | |||
| { | |||
| return Value.TryGetValue(key, out value); | |||
| } | |||
| // Optional IDictionary<TKey, TValue> members | |||
| public TValue this[TKey key] | |||
| { | |||
| get => Value[key]; | |||
| set => Value[key] = value; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,53 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// A list implementation that supports the nest structure, with a fixed depth of 1. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| public sealed class NestList<T> : INestStructure<T>, IEnumerable<T> | |||
| { | |||
| public NestType NestType => NestType.List; | |||
| public List<T> Values { get; set; } | |||
| public int ShallowNestedCount => Values.Count; | |||
| public int TotalNestedCount => Values.Count; | |||
| public NestList(params T[] values) | |||
| { | |||
| Values = new List<T>(values); | |||
| } | |||
| public NestList(IEnumerable<T> values) | |||
| { | |||
| Values = new List<T>(values); | |||
| } | |||
| public IEnumerable<T> Flatten() | |||
| { | |||
| return Values; | |||
| } | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
| { | |||
| return new NestList<TOut>(Values.Select(x => func(x))); | |||
| } | |||
| public Nest<T> AsNest() | |||
| { | |||
| return new Nest<T>(Values.Select(x => new Nest<T>(x))); | |||
| } | |||
| // Enumerator implementation | |||
| public IEnumerator<T> GetEnumerator() | |||
| { | |||
| return Values.GetEnumerator(); | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| return GetEnumerator(); | |||
| } | |||
| } | |||
| } | |||
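Editor's note: a brief sketch of the depth-1 list wrapper.

    // Sketch only.
    var sizes = new NestList<long>(64, 64);
    var doubled = sizes.MapStructure(x => x * 2);   // values 128, 128
    var flat = sizes.Flatten().ToList();            // [64, 64]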
| @@ -0,0 +1,36 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// A nested structure with only one element. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| public class NestNode<T> : INestStructure<T> | |||
| { | |||
| public NestType NestType => NestType.Node; | |||
| public T Value { get; set; } | |||
| public int ShallowNestedCount => 1; | |||
| public int TotalNestedCount => 1; | |||
| public NestNode(T value) | |||
| { | |||
| Value = value; | |||
| } | |||
| public IEnumerable<T> Flatten() | |||
| { | |||
| yield return Value; | |||
| } | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
| { | |||
| return new NestNode<TOut>(func(Value)); | |||
| } | |||
| public Nest<T> AsNest() | |||
| { | |||
| return new Nest<T>(Value); | |||
| } | |||
| } | |||
| } | |||
| @@ -3,7 +3,7 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| namespace Tensorflow.Keras.Saving | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public class TensorShapeConfig | |||
| { | |||
| @@ -161,8 +161,8 @@ namespace Tensorflow | |||
| break; | |||
| } | |||
| yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ? | |||
| null : new Tensors(results.Skip(FirstInputTensorCount))); | |||
| yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ? | |||
| null : new Tensors(results.Skip(FirstInputTensorCount).ToArray())); | |||
| } | |||
| } | |||
| @@ -352,13 +352,19 @@ namespace Tensorflow.Eager | |||
| c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value)); | |||
| break; | |||
| case TF_AttrType.TF_ATTR_SHAPE: | |||
| var dims = (value as long[]).ToArray(); | |||
| long[] dims; | |||
| if (value is Shape shape) dims = shape.dims.ToArray(); | |||
| else if (value is long[] longs) dims = longs.ToArray(); | |||
| else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray(); | |||
| else dims = ((long[])value).ToArray(); | |||
| c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status); | |||
| status.Check(true); | |||
| break; | |||
| case TF_AttrType.TF_ATTR_FUNC: | |||
| if (value is ConcreteFunction func) | |||
| c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | |||
| else if(value is string str) | |||
| c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length); | |||
| else | |||
| throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | |||
| break; | |||
| @@ -65,7 +65,7 @@ namespace Tensorflow.Eager | |||
| { | |||
| outgrad_vec = output_gradients.ToList(); | |||
| } | |||
| var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false); | |||
| var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true); | |||
| bool unconnected_gradients_zero = unconnected_gradients == "zero"; | |||
| @@ -137,7 +137,6 @@ namespace Tensorflow.Eager | |||
| { | |||
| dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); | |||
| } | |||
| Shape tensor_shape = new(dims); | |||
| if(status.Code != TF_Code.TF_OK) | |||
| { | |||
| @@ -145,6 +144,7 @@ namespace Tensorflow.Eager | |||
| } | |||
| else | |||
| { | |||
| Shape tensor_shape = new(dims); | |||
| return new TapeTensor(id, dtype, tensor_shape); | |||
| } | |||
| } | |||
| @@ -173,8 +173,12 @@ namespace Tensorflow.Eager | |||
| return dtype == dtypes.variant || dtype == dtypes.resource; | |||
| } | |||
| bool ListContainNone(long[] list) | |||
| bool ListContainNone(long[]? list) | |||
| { | |||
| if(list is null) | |||
| { | |||
| return true; | |||
| } | |||
| int len = list.Length; | |||
| if(len == 0) | |||
| { | |||
| @@ -10,6 +10,11 @@ namespace Tensorflow.Eager | |||
| var str = NDArrayRender.ToString(nd); | |||
| return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; | |||
| } | |||
| public string ToString(int maxLength) | |||
| { | |||
| var nd = new NDArray(this); | |||
| var str = NDArrayRender.ToString(nd, maxLength); | |||
| return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,19 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Exceptions | |||
| { | |||
| public class NotOkStatusException : TensorflowException | |||
| { | |||
| public NotOkStatusException() : base() | |||
| { | |||
| } | |||
| public NotOkStatusException(string message) : base(message) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow.Framework.Models | |||
| { | |||
| @@ -24,5 +25,17 @@ namespace Tensorflow.Framework.Models | |||
| shapes.Insert(0, dim); | |||
| return new TensorSpec(shapes.ToArray(), _dtype); | |||
| } | |||
| public static TensorSpec FromTensor(Tensor tensor, string? name = null) | |||
| { | |||
| if(tensor is EagerTensor) | |||
| { | |||
| return new TensorSpec(tensor.shape, tensor.dtype, name); | |||
| } | |||
| else | |||
| { | |||
| return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name); | |||
| } | |||
| } | |||
| } | |||
| } | |||
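Editor's note: a sketch of deriving a spec from an eager tensor; the name is supplied explicitly because eager tensors carry no graph name. `tf.ones` is an assumption, not part of this hunk.

    // Sketch only.
    var t = tf.ones(new Shape(2, 3));
    var spec = TensorSpec.FromTensor(t, "input");   // shape (2, 3), dtype float32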
| @@ -0,0 +1,89 @@ | |||
| using Tensorflow.Graphs; | |||
| namespace Tensorflow.Framework | |||
| { | |||
| internal static class auto_control_deps_utils | |||
| { | |||
| public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"; | |||
| public static List<int> get_read_only_resource_input_indices_graph(FuncGraph func_graph) | |||
| { | |||
| List<int> result = new List<int>(); | |||
| // A cache to store the read only resource inputs of an Op. | |||
| // Operation -> ObjectIdentitySet of resource handles. | |||
| Dictionary<Operation, HashSet<Tensor>> opReadOnlyResourceInputs = | |||
| new Dictionary<Operation, HashSet<Tensor>>(); | |||
| for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++) | |||
| { | |||
| Tensor t = func_graph.Inputs[inputIndex]; | |||
| if (t.dtype != dtypes.resource) | |||
| continue; | |||
| bool readOnly = true; | |||
| foreach (var op in t.consumers()) | |||
| { | |||
| if (opReadOnlyResourceInputs.ContainsKey(op)) | |||
| { | |||
| if (!opReadOnlyResourceInputs[op].Contains(t)) | |||
| { | |||
| readOnly = false; | |||
| break; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| List<int> indices = _get_read_only_resource_input_indices_op(op); | |||
| opReadOnlyResourceInputs[op] = new HashSet<Tensor>( | |||
| indices.Select(i => op.inputs[i])); | |||
| if (!opReadOnlyResourceInputs[op].Contains(t)) | |||
| { | |||
| readOnly = false; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (readOnly) | |||
| result.Add(inputIndex); | |||
| } | |||
| return result; | |||
| } | |||
| private static List<int> _get_read_only_resource_input_indices_op(Operation op) | |||
| { | |||
| // ignore the RESOURCE_READ_OPS | |||
| int[] read_only_input_indices; | |||
| try | |||
| { | |||
| read_only_input_indices = op.get_attr<int[]>(READ_ONLY_RESOURCE_INPUTS_ATTR); | |||
| } | |||
| catch (InvalidArgumentError) | |||
| { | |||
| return new List<int>(); | |||
| } | |||
| int read_only_index = 0; | |||
| List<int> result = new(); | |||
| for (int i = 0; i < op.inputs.Length; i++) | |||
| { | |||
| if (read_only_index >= read_only_input_indices.Length) | |||
| { | |||
| break; | |||
| } | |||
| if (op.inputs[i].dtype != dtypes.resource) | |||
| { | |||
| continue; | |||
| } | |||
| if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index]) | |||
| { | |||
| result.Add(i); | |||
| read_only_index++; | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| } | |||
| } | |||
| @@ -42,10 +42,10 @@ namespace Tensorflow.Framework | |||
| func_graph.as_default(); | |||
| importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | |||
| var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | |||
| func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
| func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); | |||
| var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | |||
| func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
| func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); | |||
| // TODO(Rinne): func_graph.ControlOutputs | |||
| _set_handle_data(func_graph, fdef); | |||
| @@ -8,6 +8,7 @@ using Tensorflow.Gradients; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Util; | |||
| using Tensorflow.Common.Extensions; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Functions | |||
| @@ -40,6 +41,18 @@ namespace Tensorflow.Functions | |||
| public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; | |||
| public IEnumerable<IVariableV1> Variables => func_graph.Variables; | |||
| public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables; | |||
| internal NameAttrList AsNameAttrList | |||
| { | |||
| get | |||
| { | |||
| NameAttrList ret = new() { Name = this.Name }; | |||
| foreach (var (name, value) in _attrs) | |||
| { | |||
| ret.Attr[name] = value; | |||
| } | |||
| return ret; | |||
| } | |||
| } | |||
| public ConcreteFunction(string name) | |||
| { | |||
| @@ -90,8 +90,7 @@ namespace Tensorflow.Gradients | |||
| ? input_values[0].rank + dim_int | |||
| : dim_int % input_values[0].rank; | |||
| var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray(); | |||
| var sizes_tensor = constant_op.constant(sizes); | |||
| out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList(); | |||
| out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList(); | |||
| } | |||
| else if (constant_op.is_constant(concat_dim)) | |||
| { | |||
| @@ -127,7 +126,7 @@ namespace Tensorflow.Gradients | |||
| new Tensor[] { non_neg_concat_dim, tf.constant(0) }, | |||
| new Tensor[] { tf.constant(1), tf.constant(-1) }); | |||
| var squeeze_sizes = array_ops.squeeze(slice); | |||
| out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList(); | |||
| out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList(); | |||
| } | |||
| else | |||
| { | |||
| @@ -81,7 +81,7 @@ public class FuncGraph : Graph, IDisposable | |||
| public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable); | |||
| public Dictionary<string, AttrValue> Attrs { get; set; } | |||
| Dictionary<long, (Tensor, Tensor)> _captures | |||
| internal Dictionary<long, (Tensor, Tensor)> _captures | |||
| = new Dictionary<long, (Tensor, Tensor)>(); | |||
| public Tensor[] external_captures | |||
| @@ -399,7 +399,7 @@ public class FuncGraph : Graph, IDisposable | |||
| var flat_func_args = nest.flatten(func_args as object); | |||
| var flat_func_kwargs = nest.flatten(func_kwargs as object); | |||
| func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) | |||
| .Where(x => x is Tensor).Select(x => (Tensor)x)); | |||
| .Where(x => x is Tensor).Select(x => (Tensor)x).ToArray()); | |||
| //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); | |||
| //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); | |||
| @@ -129,7 +129,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| protected Graph outer_graph; | |||
| internal Graph outer_graph; | |||
| public Graph OuterGraph => outer_graph; | |||
| public Dictionary<string, EagerDefinedFunction> Functions => _functions; | |||
| public SafeGraphHandle c_graph => _handle; | |||
| @@ -4,8 +4,6 @@ | |||
| { | |||
| // TODO: maybe change the `RNNArgs` and implement this class. | |||
| public bool UnitForgetBias { get; set; } | |||
| public float Dropout { get; set; } | |||
| public float RecurrentDropout { get; set; } | |||
| public int Implementation { get; set; } | |||
| } | |||
| } | |||
| @@ -1,7 +1,35 @@ | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| using Newtonsoft.Json; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| // TODO: complete the implementation | |||
| public class LSTMCellArgs : LayerArgs | |||
| public class LSTMCellArgs : AutoSerializeLayerArgs | |||
| { | |||
| [JsonProperty("units")] | |||
| public int Units { get; set; } | |||
| // TODO(Rinne): lack of initialized value of Activation. Merging keras | |||
| // into tf.net could resolve it. | |||
| [JsonProperty("activation")] | |||
| public Activation Activation { get; set; } | |||
| [JsonProperty("recurrent_activation")] | |||
| public Activation RecurrentActivation { get; set; } | |||
| [JsonProperty("use_bias")] | |||
| public bool UseBias { get; set; } = true; | |||
| [JsonProperty("dropout")] | |||
| public float Dropout { get; set; } = .0f; | |||
| [JsonProperty("recurrent_dropout")] | |||
| public float RecurrentDropout { get; set; } = .0f; | |||
| [JsonProperty("kernel_initializer")] | |||
| public IInitializer KernelInitializer { get; set; } | |||
| [JsonProperty("recurrent_initializer")] | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| [JsonProperty("bias_initializer")] | |||
| public IInitializer BiasInitializer { get; set; } | |||
| [JsonProperty("unit_forget_bias")] | |||
| public bool UnitForgetBias { get; set; } = true; | |||
| [JsonProperty("implementation")] | |||
| public int Implementation { get; set; } = 2; | |||
| } | |||
| } | |||
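Editor's note: a sketch of populating the new serializable cell arguments; only properties shown in this diff are used.

    // Sketch only.
    var args = new LSTMCellArgs
    {
        Units = 64,
        UseBias = true,
        Dropout = 0.1f,
        RecurrentDropout = 0.1f,
        UnitForgetBias = true,
        Implementation = 2,
    };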
| @@ -1,17 +1,12 @@ | |||
| using Newtonsoft.Json; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras.Layers.Rnn; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| // TODO(Rinne): add regularizers. | |||
| public class RNNArgs : AutoSerializeLayerArgs | |||
| { | |||
| public interface IRnnArgCell : ILayer | |||
| { | |||
| object state_size { get; } | |||
| } | |||
| [JsonProperty("cell")] | |||
| // TODO: the cell should be serialized with `serialize_keras_object`. | |||
| public IRnnArgCell Cell { get; set; } = null; | |||
| [JsonProperty("return_sequences")] | |||
| public bool ReturnSequences { get; set; } = false; | |||
| [JsonProperty("return_state")] | |||
| @@ -24,8 +19,10 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| public bool Unroll { get; set; } = false; | |||
| [JsonProperty("time_major")] | |||
| public bool TimeMajor { get; set; } = false; | |||
| public int? InputDim { get; set; } | |||
| public int? InputLength { get; set; } | |||
| // TODO: Add `num_constants` and `zero_output_for_mask`. | |||
| public Dictionary<string, object> Kwargs { get; set; } = null; | |||
| public int Units { get; set; } | |||
| public Activation Activation { get; set; } | |||
| @@ -34,21 +31,8 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| public IInitializer KernelInitializer { get; set; } | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| public IInitializer BiasInitializer { get; set; } | |||
| // kernel_regularizer=None, | |||
| // recurrent_regularizer=None, | |||
| // bias_regularizer=None, | |||
| // activity_regularizer=None, | |||
| // kernel_constraint=None, | |||
| // recurrent_constraint=None, | |||
| // bias_constraint=None, | |||
| // dropout=0., | |||
| // recurrent_dropout=0., | |||
| // return_sequences=False, | |||
| // return_state=False, | |||
| // go_backwards=False, | |||
| // stateful=False, | |||
| // unroll=False, | |||
| // **kwargs): | |||
| public float Dropout { get; set; } = .0f; | |||
| public bool ZeroOutputForMask { get; set; } = false; | |||
| public float RecurrentDropout { get; set; } = .0f; | |||
| } | |||
| } | |||
| @@ -0,0 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| public class RnnOptionalArgs: IOptionalArgs | |||
| { | |||
| public string Identifier => "Rnn"; | |||
| public Tensor Mask { get; set; } = null; | |||
| public Tensors Constants { get; set; } = null; | |||
| } | |||
| } | |||
| @@ -0,0 +1,27 @@ | |||
| using Newtonsoft.Json; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| public class SimpleRNNCellArgs: AutoSerializeLayerArgs | |||
| { | |||
| [JsonProperty("units")] | |||
| public int Units { get; set; } | |||
| // TODO(Rinne): lack of initialized value of Activation. Merging keras | |||
| // into tf.net could resolve it. | |||
| [JsonProperty("activation")] | |||
| public Activation Activation { get; set; } | |||
| [JsonProperty("use_bias")] | |||
| public bool UseBias { get; set; } = true; | |||
| [JsonProperty("dropout")] | |||
| public float Dropout { get; set; } = .0f; | |||
| [JsonProperty("recurrent_dropout")] | |||
| public float RecurrentDropout { get; set; } = .0f; | |||
| [JsonProperty("kernel_initializer")] | |||
| public IInitializer KernelInitializer { get; set; } | |||
| [JsonProperty("recurrent_initializer")] | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| [JsonProperty("bias_initializer")] | |||
| public IInitializer BiasInitializer { get; set; } | |||
| } | |||
| } | |||
| @@ -1,10 +1,10 @@ | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras.Layers.Rnn; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| public class StackedRNNCellsArgs : LayerArgs | |||
| { | |||
| public IList<RnnCell> Cells { get; set; } | |||
| public Dictionary<string, object> Kwargs { get; set; } = null; | |||
| public bool ReverseStateOrder = false; | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Training; | |||
| @@ -14,7 +15,7 @@ namespace Tensorflow.Keras | |||
| List<ILayer> Layers { get; } | |||
| List<INode> InboundNodes { get; } | |||
| List<INode> OutboundNodes { get; } | |||
| Tensors Apply(Tensors inputs, Tensor state = null, bool training = false); | |||
| Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null); | |||
| List<IVariableV1> TrainableVariables { get; } | |||
| List<IVariableV1> TrainableWeights { get; } | |||
| List<IVariableV1> NonTrainableWeights { get; } | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Keras.Layers.Rnn; | |||
| using Tensorflow.NumPy; | |||
| using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | |||
| @@ -159,6 +160,18 @@ namespace Tensorflow.Keras.Layers | |||
| public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false); | |||
| public ILayer LeakyReLU(float alpha = 0.3f); | |||
| public IRnnCell LSTMCell(int units, | |||
| string activation = "tanh", | |||
| string recurrent_activation = "sigmoid", | |||
| bool use_bias = true, | |||
| string kernel_initializer = "glorot_uniform", | |||
| string recurrent_initializer = "orthogonal", | |||
| string bias_initializer = "zeros", | |||
| bool unit_forget_bias = true, | |||
| float dropout = 0f, | |||
| float recurrent_dropout = 0f, | |||
| int implementation = 2); | |||
| public ILayer LSTM(int units, | |||
| Activation activation = null, | |||
| Activation recurrent_activation = null, | |||
| @@ -192,6 +205,19 @@ namespace Tensorflow.Keras.Layers | |||
| float offset = 0, | |||
| Shape input_shape = null); | |||
| public IRnnCell SimpleRNNCell( | |||
| int units, | |||
| string activation = "tanh", | |||
| bool use_bias = true, | |||
| string kernel_initializer = "glorot_uniform", | |||
| string recurrent_initializer = "orthogonal", | |||
| string bias_initializer = "zeros", | |||
| float dropout = 0f, | |||
| float recurrent_dropout = 0f); | |||
| public IRnnCell StackedRNNCells( | |||
| IEnumerable<IRnnCell> cells); | |||
| public ILayer SimpleRNN(int units, | |||
| string activation = "tanh", | |||
| string kernel_initializer = "glorot_uniform", | |||
| @@ -200,6 +226,26 @@ namespace Tensorflow.Keras.Layers | |||
| bool return_sequences = false, | |||
| bool return_state = false); | |||
| public ILayer RNN( | |||
| IRnnCell cell, | |||
| bool return_sequences = false, | |||
| bool return_state = false, | |||
| bool go_backwards = false, | |||
| bool stateful = false, | |||
| bool unroll = false, | |||
| bool time_major = false | |||
| ); | |||
| public ILayer RNN( | |||
| IEnumerable<IRnnCell> cell, | |||
| bool return_sequences = false, | |||
| bool return_state = false, | |||
| bool go_backwards = false, | |||
| bool stateful = false, | |||
| bool unroll = false, | |||
| bool time_major = false | |||
| ); | |||
| public ILayer Subtract(); | |||
| } | |||
| } | |||
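Editor's note: a sketch of wiring the new factories together. It assumes the Keras package's `tf.keras.layers` entry point, `tf.ones`, and an implicit `Tensor` to `Tensors` conversion; the shapes are arbitrary.

    // Sketch only: a simple RNN over a (batch=8, time=5, features=3) input.
    var cell = tf.keras.layers.SimpleRNNCell(4);
    var rnn = tf.keras.layers.RNN(cell, return_sequences: true);
    var outputs = rnn.Apply(tf.ones(new Shape(8, 5, 3)));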
| @@ -0,0 +1,25 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| public interface IRnnCell: ILayer | |||
| { | |||
| /// <summary> | |||
| /// If the derived class does not implement it, return null. | |||
| /// </summary> | |||
| INestStructure<long>? StateSize { get; } | |||
| /// <summary> | |||
| /// If the derived class does not implement it, return null. | |||
| /// </summary> | |||
| INestStructure<long>? OutputSize { get; } | |||
| /// <summary> | |||
| /// Whether the optional RNN args are supported when applying the layer. | |||
| /// In other words, whether `Apply` is overridden to process `RnnOptionalArgs`. | |||
| /// </summary> | |||
| bool SupportOptionalArgs { get; } | |||
| Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype); | |||
| } | |||
| } | |||
| @@ -0,0 +1,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| public interface IStackedRnnCells : IRnnCell | |||
| { | |||
| int Count { get; } | |||
| IRnnCell this[int idx] { get; } | |||
| } | |||
| } | |||
| @@ -3,6 +3,7 @@ using Newtonsoft.Json; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Saving.Json | |||
| { | |||
| @@ -6,6 +6,7 @@ using System.Text; | |||
| using System.Diagnostics; | |||
| using OneOf.Types; | |||
| using Tensorflow.Keras.Saving.Json; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Saving | |||
| { | |||
| @@ -74,8 +74,3 @@ namespace Tensorflow | |||
| => IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})"; | |||
| } | |||
| } | |||
| namespace System.Runtime.CompilerServices | |||
| { | |||
| internal static class IsExternalInit { } | |||
| } | |||
| @@ -7,7 +7,7 @@ namespace Tensorflow.NumPy | |||
| { | |||
| public class NDArrayRender | |||
| { | |||
| public static string ToString(NDArray array) | |||
| public static string ToString(NDArray array, int maxLength = 10) | |||
| { | |||
| Shape shape = array.shape; | |||
| if (shape.IsScalar) | |||
| @@ -15,12 +15,12 @@ namespace Tensorflow.NumPy | |||
| var s = new StringBuilder(); | |||
| s.Append("array("); | |||
| Build(s, array); | |||
| Build(s, array, maxLength); | |||
| s.Append(")"); | |||
| return s.ToString(); | |||
| } | |||
| static void Build(StringBuilder s, NDArray array) | |||
| static void Build(StringBuilder s, NDArray array, int maxLength) | |||
| { | |||
| var shape = array.shape; | |||
| @@ -35,11 +35,11 @@ namespace Tensorflow.NumPy | |||
| var len = shape[0]; | |||
| s.Append("["); | |||
| if (len <= 10) | |||
| if (len <= maxLength) | |||
| { | |||
| for (int i = 0; i < len; i++) | |||
| { | |||
| Build(s, array[i]); | |||
| Build(s, array[i], maxLength); | |||
| if (i < len - 1) | |||
| { | |||
| s.Append(", "); | |||
| @@ -49,9 +49,9 @@ namespace Tensorflow.NumPy | |||
| } | |||
| else | |||
| { | |||
| for (int i = 0; i < 5; i++) | |||
| for (int i = 0; i < maxLength / 2; i++) | |||
| { | |||
| Build(s, array[i]); | |||
| Build(s, array[i], maxLength); | |||
| if (i < len - 1) | |||
| { | |||
| s.Append(", "); | |||
| @@ -62,9 +62,9 @@ namespace Tensorflow.NumPy | |||
| s.Append(" ... "); | |||
| s.AppendLine(); | |||
| for (int i = (int)len - 5; i < len; i++) | |||
| for (int i = (int)len - maxLength / 2; i < len; i++) | |||
| { | |||
| Build(s, array[i]); | |||
| Build(s, array[i], maxLength); | |||
| if (i < len - 1) | |||
| { | |||
| s.Append(", "); | |||
| @@ -19,13 +19,14 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.Saving.Common; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow | |||
| { | |||
| [JsonConverter(typeof(CustomizedShapeJsonConverter))] | |||
| public class Shape | |||
| public class Shape : INestStructure<long> | |||
| { | |||
| public int ndim => _dims == null ? -1 : _dims.Length; | |||
| long[] _dims; | |||
| @@ -41,6 +42,27 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public NestType NestType => NestType.List; | |||
| public int ShallowNestedCount => ndim; | |||
| /// <summary> | |||
| /// The total leaf count of the nested structure. For a shape this equals its ndim. | |||
| /// </summary> | |||
| public int TotalNestedCount => ndim; | |||
| public IEnumerable<long> Flatten() => dims.Select(x => x); | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<long, TOut> func) | |||
| { | |||
| return new NestList<TOut>(dims.Select(x => func(x))); | |||
| } | |||
| public Nest<long> AsNest() | |||
| { | |||
| return new NestList<long>(Flatten()).AsNest(); | |||
| } | |||
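Since Shape now implements INestStructure&lt;long&gt;, its dimensions can be consumed through the same nest API the RNN code uses for Tensors. A small sketch, assuming the usual Shape(params long[]) constructor and System.Linq:

using System;
using System.Linq;
using Tensorflow;
using Tensorflow.Common.Types;

var shape = new Shape(2, 3, 4);
long[] dims = shape.Flatten().ToArray();                        // { 2, 3, 4 }
INestStructure<long> doubled = shape.MapStructure(d => d * 2);  // a NestList<long> holding 4, 6, 8
Nest<long> nest = shape.AsNest();                               // the same dims wrapped as a flat Nest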
| #region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges | |||
| public int Length => ndim; | |||
| public long[] Slice(int start, int length) | |||
| @@ -0,0 +1,22 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow.Operations.Initializers | |||
| { | |||
| /// <summary> | |||
| /// An initializer intended only for debugging: it loads previously saved weights from an .npy file on disk. | |||
| /// </summary> | |||
| class NpyLoadInitializer : IInitializer | |||
| { | |||
| string _path; | |||
| public NpyLoadInitializer(string path) { _path = path; } | |||
| public string ClassName => ""; | |||
| public IDictionary<string, object> Config => new Dictionary<string, object>(); | |||
| public Tensor Apply(InitializerArgs args) | |||
| { | |||
| return np.load(_path); | |||
| } | |||
| } | |||
| } | |||
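NpyLoadInitializer returns whatever np.load reads from the given path and ignores the initialization arguments entirely, so the file must already match the target variable. A hedged sketch: the file name is hypothetical, the class is internal to the library, and the InitializerArgs(Shape) constructor is assumed to be the same one the other initializers accept.

using Tensorflow;
using Tensorflow.Operations.Initializers;

// Debug only: replace random initialization with weights previously dumped to disk.
IInitializer init = new NpyLoadInitializer("dumped_kernel.npy");     // hypothetical path
// The args are ignored by this initializer, so the .npy contents must already
// carry the shape and dtype the variable expects.
Tensor kernel = init.Apply(new InitializerArgs(new Shape(128, 64)));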
| @@ -53,13 +53,12 @@ public class Orthogonal : IInitializer | |||
| // Compute the qr factorization | |||
| var (q, r) = tf.linalg.qr(a, full_matrices: false); | |||
| // Make Q uniform | |||
| var d = tf.linalg.tensor_diag_part(r); | |||
| var d = tf.linalg.tensor_diag_part(r.Single); | |||
| q *= tf.sign(d); | |||
| if (num_rows < num_cols) | |||
| { | |||
| // q = tf.linalg.matrix_transpose(q); | |||
| throw new NotImplementedException(""); | |||
| q = array_ops.matrix_transpose(q); | |||
| } | |||
| return _gain * tf.reshape(q, shape); | |||
| @@ -11,6 +11,7 @@ namespace Tensorflow | |||
| /// Basic LSTM recurrent network cell. | |||
| /// The implementation is based on: http://arxiv.org/abs/1409.2329. | |||
| /// </summary> | |||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
| public class BasicLstmCell : LayerRnnCell | |||
| { | |||
| int _num_units; | |||
| @@ -88,7 +89,7 @@ namespace Tensorflow | |||
| gate_inputs = nn_ops.bias_add(gate_inputs, _bias); | |||
| // i = input_gate, j = new_input, f = forget_gate, o = output_gate | |||
| var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one); | |||
| var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one); | |||
| var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); | |||
| var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); | |||
| @@ -20,6 +20,7 @@ using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
| public class BasicRnnCell : LayerRnnCell | |||
| { | |||
| int _num_units; | |||
| @@ -19,6 +19,7 @@ using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
| public class LayerRnnCell : RnnCell | |||
| { | |||
| protected InputSpec inputSpec; | |||
| @@ -16,10 +16,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers.Rnn; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Operations; | |||
| @@ -50,7 +52,8 @@ namespace Tensorflow | |||
| /// matching structure of Tensors having shape `[batch_size].concatenate(s)` | |||
| /// for each `s` in `self.state_size`. | |||
| /// </summary> | |||
| public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell | |||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
| public abstract class RnnCell : ILayer, IRnnCell | |||
| { | |||
| /// <summary> | |||
| /// Attribute that indicates whether the cell is a TF RNN cell, due the slight | |||
| @@ -142,7 +145,7 @@ namespace Tensorflow | |||
| throw new NotImplementedException("_zero_state_tensors"); | |||
| } | |||
| public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| @@ -173,5 +176,18 @@ namespace Tensorflow | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public INestStructure<long> StateSize => throw new NotImplementedException(); | |||
| public INestStructure<long> OutputSize => throw new NotImplementedException(); | |||
| public bool IsTFRnnCell => throw new NotImplementedException(); | |||
| public bool SupportOptionalArgs => throw new NotImplementedException(); | |||
| } | |||
| } | |||
| @@ -15,9 +15,11 @@ | |||
| ******************************************************************************/ | |||
| using Google.Protobuf; | |||
| using Google.Protobuf.Collections; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Functions; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.OpDef.Types; | |||
| @@ -387,9 +389,13 @@ namespace Tensorflow | |||
| case "list(type)": | |||
| attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def))); | |||
| break; | |||
| case "list(float)": | |||
| if (value != null) | |||
| attr_value.List.F.AddRange((value as IEnumerable<float>).ToArray()); | |||
| break; | |||
| case "list(int)": | |||
| if (value != null) | |||
| attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x))); | |||
| attr_value.List.I.AddRange((value as IEnumerable<int>).Select(x => Convert.ToInt64(x))); | |||
| break; | |||
| case "bool": | |||
| attr_value.B = (bool)value; | |||
| @@ -420,6 +426,15 @@ namespace Tensorflow | |||
| case "list(shape)": | |||
| attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def))); | |||
| break; | |||
| case "func": | |||
| attr_value.Func = _MakeFunc(value, attr_def.Name); | |||
| break; | |||
| case "list(func)": | |||
| attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name)); | |||
| break; | |||
| case "list(string)": | |||
| attr_value.List.S.AddRange((value as IEnumerable<string>).Select(x => ByteString.CopyFromUtf8(x))); | |||
| break; | |||
| default: | |||
| throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | |||
| } | |||
| @@ -427,6 +442,47 @@ namespace Tensorflow | |||
| return attr_value; | |||
| } | |||
| private NameAttrList _MakeFunc(object func, string arg_name) | |||
| { | |||
| if(func is NameAttrList attrList) | |||
| { | |||
| return attrList; | |||
| } | |||
| NameAttrList fn_attr; | |||
| if(func is string funcStr) | |||
| { | |||
| fn_attr = new NameAttrList() { Name = funcStr }; | |||
| } | |||
| else if(func is ConcreteFunction concrete) | |||
| { | |||
| concrete.AddTograph(ops.get_default_graph()); | |||
| fn_attr = concrete.AsNameAttrList; | |||
| } | |||
| else if(func is EagerDefinedFunction eager) | |||
| { | |||
| eager.AddToGraph(ops.get_default_graph()); | |||
| fn_attr = new NameAttrList() { Name = eager.Name }; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}"); | |||
| } | |||
| return fn_attr; | |||
| } | |||
| private List<NameAttrList> _MakeFuncList(object funcList, string arg_name) | |||
| { | |||
| List<NameAttrList> res = new List<NameAttrList>(); | |||
| if(funcList is IEnumerable enumerable) | |||
| { | |||
| foreach(var func in enumerable) | |||
| { | |||
| res.Add(_MakeFunc(func, arg_name)); | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| private bool _IsListParameter(ArgDef arg) | |||
| { | |||
| if (!String.IsNullOrEmpty(arg.NumberAttr)) | |||
| @@ -34,7 +34,7 @@ namespace Tensorflow | |||
| return num; | |||
| } | |||
| protected Tensor[] _outputs; | |||
| internal Tensor[] _outputs; | |||
| public virtual Tensor[] outputs => _outputs; | |||
| public Tensor output => _outputs.FirstOrDefault(); | |||
| @@ -46,9 +46,9 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public partial class Operation : ITensorOrOperation | |||
| { | |||
| private readonly IntPtr _handle; // _c_op in python | |||
| protected IntPtr _handle; // _c_op in python | |||
| private readonly Graph _graph; | |||
| protected Graph _graph; | |||
| internal Func<Operation, object[], Tensor[]> _gradient_function; | |||
| @@ -69,6 +69,7 @@ namespace Tensorflow | |||
| //private OperationDescription _op_desc; | |||
| public NodeDef node_def => GetNodeDef(); | |||
| protected Operation() { } | |||
| public Operation(IntPtr handle, Graph g = null) | |||
| { | |||
| @@ -185,7 +186,16 @@ namespace Tensorflow | |||
| } | |||
| public virtual T get_attr<T>(string name) | |||
| => (T)get_attr(name); | |||
| { | |||
| if (typeof(T).IsValueType) | |||
| { | |||
| return (T)Convert.ChangeType(get_attr(name), typeof(T)); | |||
| } | |||
| else | |||
| { | |||
| return (T)get_attr(name); | |||
| } | |||
| } | |||
| internal unsafe TF_DataType _get_attr_type(string name) | |||
| { | |||
| @@ -17,6 +17,8 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Framework; | |||
| using static Tensorflow.Binding; | |||
| @@ -37,10 +39,6 @@ namespace Tensorflow.Operations | |||
| bool _infer_shape; | |||
| public override bool infer_shape => _infer_shape; | |||
| public bool _dynamic_size; | |||
| public Shape _element_shape; | |||
| public List<Tensor> _colocate_with; | |||
| Tensor _handle; | |||
| public override Tensor handle => _handle; | |||
| @@ -48,12 +46,14 @@ namespace Tensorflow.Operations | |||
| public override Tensor flow => _flow; | |||
| bool _clear_after_read; | |||
| List<Tensor> _tensor_array; | |||
| List<int> _previous_read_indices; | |||
| public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false, | |||
| bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
| bool infer_shape = true, Shape? element_shape = null, | |||
| bool colocate_with_first_write_call = true, string name = null) | |||
| { | |||
| _size = size; | |||
| _flow = constant_op.constant(0); | |||
| _infer_shape = infer_shape; | |||
| _element_shape = element_shape ?? Shape.Null; | |||
| @@ -61,16 +61,20 @@ namespace Tensorflow.Operations | |||
| _dtype = dtype.as_base_dtype(); | |||
| _dynamic_size = dynamic_size; | |||
| _clear_after_read = clear_after_read; | |||
| _tensor_array = new List<Tensor>(); | |||
| _tensor_array = Enumerable.Repeat<Tensor>(null, size.numpy()).ToList(); | |||
| _previous_read_indices = new(); | |||
| } | |||
| public override TensorArray unstack(Tensor value, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate | |||
| var tensors = array_ops.unstack(value, name: name); | |||
| if(tensors.Length > _tensor_array.Count && !_dynamic_size) | |||
| { | |||
| var num_elements = array_ops.shape(value)[0]; | |||
| return scatter(indices: math_ops.range(0, num_elements), value: value, name: name); | |||
| }); | |||
| throw new ValueError($"Cannot unstack {tensors.Length} tensors into a TensorArray of static size {_tensor_array.Count}"); | |||
| } | |||
| _tensor_array = tensors.ToList(); | |||
| // TODO(Rinne): revise the implementation. Here we should return `parent()`. | |||
| return this; | |||
| } | |||
| public TensorArray scatter(Tensor indices, Tensor value, string name = null) | |||
| @@ -103,7 +107,19 @@ namespace Tensorflow.Operations | |||
| return ta; | |||
| });*/ | |||
| throw new NotImplementedException(""); | |||
| //if (indices is EagerTensor) | |||
| //{ | |||
| // indices = indices as EagerTensor; | |||
| // indices = indices.numpy(); | |||
| //} | |||
| //foreach (var (index, val) in zip(indices.ToArray<int>(), array_ops.unstack(value))) | |||
| //{ | |||
| // this.write(index, val); | |||
| //} | |||
| //return base; | |||
| //throw new NotImplementedException(""); | |||
| return this; | |||
| } | |||
| public void _merge_element_shape(Shape shape) | |||
| @@ -116,9 +132,19 @@ namespace Tensorflow.Operations | |||
| _colocate_with.Add(value); | |||
| } | |||
| private Tensor _maybe_zero(int ix) | |||
| { | |||
| var val = _tensor_array[ix]; | |||
| if(val is null) | |||
| { | |||
| val = _tensor_array[ix] = array_ops.zeros(_element_shape, _dtype); | |||
| } | |||
| return val; | |||
| } | |||
| public override Tensor read<T>(T index, string name = null) | |||
| { | |||
| int index_int = -1; | |||
| int index_int; | |||
| if (index is int int_index) | |||
| index_int = int_index; | |||
| else if (index is Tensor tensor_index) | |||
| @@ -126,27 +152,75 @@ namespace Tensorflow.Operations | |||
| else | |||
| throw new ValueError(""); | |||
| if(index_int >= _tensor_array.Count) | |||
| { | |||
| throw new OutOfRangeError($"Tried to read from index {index_int} but array size is: {_tensor_array.Count} "); | |||
| } | |||
| var res = _tensor_array[index_int]; | |||
| if(res is null) | |||
| { | |||
| if (_previous_read_indices.Contains(index_int)) | |||
| { | |||
| throw new InvalidArgumentError($"Could not read index {index_int} twice because it was cleared after " + | |||
| $"a previous read (perhaps try setting clear_after_read = false?)"); | |||
| } | |||
| else | |||
| { | |||
| res = _maybe_zero(index_int); | |||
| } | |||
| } | |||
| if (_clear_after_read) | |||
| { | |||
| _tensor_array[index_int] = null; | |||
| _previous_read_indices.Add(index_int); | |||
| } | |||
| return _tensor_array[index_int]; | |||
| return res; | |||
| } | |||
| public override TensorArray write(Tensor index, Tensor value, string name = null) | |||
| { | |||
| if (_infer_shape) | |||
| _element_shape = _element_shape.merge_with(value.shape); | |||
| _tensor_array.add(value); | |||
| return this; | |||
| if(index is EagerTensor eager) | |||
| { | |||
| return write<Tensor>(eager.numpy(), value, name); | |||
| } | |||
| throw new InvalidArgumentError("The index is supposed to be an EagerTensor"); | |||
| } | |||
| public override TensorArray write<T>(int index, T value, string name = null) | |||
| { | |||
| var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
| var index_tensor = ops.convert_to_tensor(index, name: "index"); | |||
| return write(index_tensor, value_tensor, name: name); | |||
| int size = _tensor_array.Count; | |||
| if(index >= size) | |||
| { | |||
| if (!_dynamic_size) | |||
| { | |||
| throw new OutOfRangeError($"Tried to write to index {index} but array is not resizeable and size " + | |||
| $"is: {size} "); | |||
| } | |||
| _tensor_array.AddRange(Enumerable.Repeat<Tensor>(null, index - size + 1)); | |||
| } | |||
| Tensor tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
| if(_dtype != tensor.dtype) | |||
| { | |||
| throw new InvalidArgumentError($"TensorArray dtype is {_dtype.as_python_name()} but Op is " + | |||
| $"trying to write dtype {tensor.dtype.as_python_name()} "); | |||
| } | |||
| if (!_element_shape.is_compatible_with(tensor.shape)) | |||
| { | |||
| throw new ValueError($"Incompatible shape for value ({tensor.shape}), expected ({_element_shape})"); | |||
| } | |||
| if (_infer_shape) | |||
| { | |||
| _element_shape = _element_shape.merge_with(tensor.shape); | |||
| } | |||
| _tensor_array[index] = tensor; | |||
| return this; | |||
| } | |||
| private Tensor size(string name = null) | |||
| @@ -156,11 +230,26 @@ namespace Tensorflow.Operations | |||
| public override Tensor stack(string name = null) | |||
| { | |||
| ops.colocate_with(_handle); | |||
| return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate | |||
| if(_tensor_array.Count > 0) | |||
| { | |||
| for(int i = 0; i < _tensor_array.Count; i++) | |||
| { | |||
| _maybe_zero(i); | |||
| } | |||
| } | |||
| if(_tensor_array.Count == 0 && _element_shape.IsFullyDefined) | |||
| { | |||
| return ops.convert_to_tensor(new Shape(new long[] { 0 }.Concat(_element_shape.dims).ToArray()), name: name, dtype: _dtype); | |||
| } | |||
| else | |||
| { | |||
| return gather(math_ops.range(0, size()), name: name); | |||
| }); | |||
| return ops.convert_to_tensor(_tensor_array, name: name, dtype: _dtype); | |||
| } | |||
| //ops.colocate_with(_handle); | |||
| //return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate | |||
| //{ | |||
| // return gather(math_ops.range(0, size()), name: name); | |||
| //}); | |||
| } | |||
| public override Tensor gather(Tensor indices, string name = null) | |||
| @@ -16,7 +16,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Eager; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Operations | |||
| @@ -32,18 +35,18 @@ namespace Tensorflow.Operations | |||
| /// first tensor written to it. | |||
| /// </summary> | |||
| bool _colocate_with_first_write_call; | |||
| public bool colocate_with_first_write_call => _colocate_with_first_write_call; | |||
| public override bool colocate_with_first_write_call => _colocate_with_first_write_call; | |||
| bool _infer_shape; | |||
| public bool infer_shape => _infer_shape; | |||
| public bool _dynamic_size; | |||
| public override bool infer_shape => _infer_shape; | |||
| public List<Shape> _element_shape; | |||
| public List<Tensor> _colocate_with; | |||
| internal Tensor _handle; | |||
| public Tensor handle => _handle; | |||
| public override Tensor handle => _handle; | |||
| internal Tensor _flow; | |||
| public override Tensor flow => _flow; | |||
| public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | |||
| bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
| @@ -54,6 +57,7 @@ namespace Tensorflow.Operations | |||
| dynamic_size = dynamic_size ?? false; | |||
| _dynamic_size = dynamic_size.Value; | |||
| _dtype = dtype; | |||
| _size = size; | |||
| _colocate_with_first_write_call = colocate_with_first_write_call; | |||
| if (colocate_with_first_write_call) | |||
| @@ -146,7 +150,9 @@ namespace Tensorflow.Operations | |||
| return ta; | |||
| });*/ | |||
| throw new NotImplementedException(""); | |||
| //throw new NotImplementedException(""); | |||
| return this; | |||
| } | |||
| public void _merge_element_shape(Shape shape) | |||
| @@ -232,4 +238,173 @@ namespace Tensorflow.Operations | |||
| return value; | |||
| } | |||
| } | |||
| public class _GraphTensorArrayV2 : TensorArray | |||
| { | |||
| internal TF_DataType _dtype; | |||
| public override TF_DataType dtype => _dtype; | |||
| /// <summary> | |||
| /// Used to keep track of what tensors the TensorArray should be | |||
| /// colocated with. We choose to colocate the TensorArray with the | |||
| /// first tensor written to it. | |||
| /// </summary> | |||
| bool _colocate_with_first_write_call; | |||
| public override bool colocate_with_first_write_call => _colocate_with_first_write_call; | |||
| bool _infer_shape; | |||
| public override bool infer_shape => _infer_shape; | |||
| public Shape _element_shape; | |||
| public List<Tensor> _colocate_with; | |||
| internal Tensor _handle; | |||
| public override Tensor handle => _handle; | |||
| internal Tensor _flow; | |||
| public override Tensor flow => _flow; | |||
| public _GraphTensorArrayV2(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | |||
| bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
| bool infer_shape = true, Shape? element_shape = null, | |||
| bool colocate_with_first_write_call = true, string name = null) | |||
| { | |||
| Debug.Assert(handle is null); | |||
| dynamic_size = dynamic_size ?? false; | |||
| _dynamic_size = dynamic_size.Value; | |||
| _size = size; | |||
| if(flow is not null && flow.dtype != dtypes.variant) | |||
| { | |||
| throw new TypeError($"Expected `flow` to be a variant tensor, but received `{flow.dtype}` instead"); | |||
| } | |||
| if(flow is null && size is null) | |||
| { | |||
| throw new ValueError("Argument `size` must be provided if argument `flow` is not provided."); | |||
| } | |||
| if(flow is not null && size is not null) | |||
| { | |||
| throw new ValueError("Cannot provide both `flow` and `size` arguments at the same time."); | |||
| } | |||
| if(flow is not null && element_shape is not null) | |||
| { | |||
| throw new ValueError("Cannot provide both `flow` and `element_shape` arguments at the same time."); | |||
| } | |||
| _dtype = dtype; | |||
| _element_shape = element_shape; | |||
| _infer_shape = infer_shape; | |||
| tf_with(ops.name_scope(name, "TensorArrayV2", new object[] { size, flow }), scope => | |||
| { | |||
| if (flow is null) | |||
| { | |||
| _flow = list_ops.tensor_list_reserve(element_shape, size, dtype, scope.scope_name); | |||
| } | |||
| else | |||
| { | |||
| _flow = flow; | |||
| } | |||
| }); | |||
| _colocate_with_first_write_call = false; | |||
| _colocate_with = null; | |||
| } | |||
| public override TensorArray unstack(Tensor value, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _flow, value }), delegate | |||
| { | |||
| value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
| Debug.Assert(value.dtype == _dtype); | |||
| var flow_out = list_ops.tensor_list_from_tensor(value, value.shape.dims.Skip(1).ToArray()); | |||
| return tensor_array_ops.build_ta_with_new_flow(this, flow_out); | |||
| }); | |||
| } | |||
| public TensorArray scatter(Tensor indices, Tensor value, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _flow, value, indices }), delegate | |||
| { | |||
| value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
| Debug.Assert(value.dtype == _dtype); | |||
| var flow_out = list_ops.tensor_list_scatter(value, indices, _element_shape, _flow); | |||
| return tensor_array_ops.build_ta_with_new_flow(this, flow_out); | |||
| }); | |||
| } | |||
| public override Tensor read<T>(T index, string name = null) | |||
| { | |||
| if(index is Tensor tensor) | |||
| { | |||
| return read(tensor, name); | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError("Please use non-generic method instead."); | |||
| } | |||
| } | |||
| public Tensor read(Tensor index, string name = null) | |||
| { | |||
| return tf_with(tf.name_scope(name, "TensorArrayV2Read", new object[] { _flow, index }), scope => | |||
| { | |||
| return list_ops.tensor_list_get_item(_flow, index, _dtype, _element_shape, name); | |||
| }); | |||
| } | |||
| public override TensorArray write(Tensor index, Tensor value, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "TensorArrayV2Write", new { _flow, index, value }), delegate | |||
| { | |||
| value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
| Debug.Assert(value.dtype == _dtype); | |||
| var flow_out = list_ops.tensor_list_set_item(_flow, index, value, _dynamic_size, name); | |||
| return tensor_array_ops.build_ta_with_new_flow(this, flow_out); | |||
| }); | |||
| } | |||
| public override TensorArray write<T>(int index, T value, string name = null) | |||
| { | |||
| var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
| var index_tensor = ops.convert_to_tensor(index, name: "index"); | |||
| return write(index_tensor, value_tensor); | |||
| } | |||
| private Tensor size(string name = null) | |||
| { | |||
| if(!_dynamic_size && _size is not null) | |||
| { | |||
| return ops.convert_to_tensor(_size, dtypes.int32); | |||
| } | |||
| else | |||
| { | |||
| return gen_list_ops.tensor_list_length(_flow, name); | |||
| } | |||
| } | |||
| public override Tensor stack(string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "TensorArrayV2Stack", _flow), delegate | |||
| { | |||
| int ta_size; | |||
| if(!_dynamic_size && (_size is not null)) | |||
| { | |||
| var size_tensor = tensor_util.constant_value(_size); | |||
| ta_size = size_tensor is null ? -1 : (int)size_tensor; | |||
| } | |||
| else | |||
| { | |||
| ta_size = -1; | |||
| } | |||
| var value = list_ops.tensor_list_stack(_flow, _dtype, ta_size, _element_shape); | |||
| return value; | |||
| }); | |||
| } | |||
| public override Tensor gather(Tensor indices, string name = null) | |||
| { | |||
| return list_ops.tensor_list_gather(_flow, indices, _dtype, _element_shape, name); | |||
| } | |||
| } | |||
| } | |||
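The eager-mode counterpart earlier in this change (_EagerTensorArray) now keeps an explicit, pre-sized List&lt;Tensor&gt; backing store, zero-fills unwritten slots, and tracks indices cleared by clear_after_read. A usage sketch, assuming the class is reachable from the caller (it is normally constructed through the tf.TensorArray wrapper):

using Tensorflow;
using Tensorflow.Operations;
using static Tensorflow.Binding;

var ta = new _EagerTensorArray(tf.float32, ops.convert_to_tensor(3));
ta.write(0, tf.constant(1f));
ta.write(1, tf.constant(2f));

var first = ta.read(0);     // 1.0; the slot is cleared because clear_after_read defaults to true
var stacked = ta.stack();   // cleared or never-written slots come back zero-filled via _maybe_zero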
| @@ -119,6 +119,27 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public static Tensor zeros(Tensors shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||
| { | |||
| dtype = dtype.as_base_dtype(); | |||
| Tensor shapeTensor; | |||
| if(shape.Length > 1) | |||
| { | |||
| shapeTensor = ops.convert_to_tensor(shape, dtypes.int32); | |||
| if(shapeTensor.ndim > 1) | |||
| { | |||
| shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1)); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| shapeTensor = shape[0]; | |||
| } | |||
| var output = fill(shapeTensor, array_ops.constant(0, dtype), name); | |||
| Debug.Assert(output.dtype.as_base_dtype() == dtype); | |||
| return output; | |||
| } | |||
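The new zeros overload accepts a dynamic shape supplied as tensors rather than a static Shape. A sketch, assuming the params-Tensor constructor of Tensors:

using Tensorflow;
using static Tensorflow.Binding;

// A dynamic shape held in a rank-1 int32 tensor (for example, one produced by array_ops.shape at run time).
var dyn_shape = ops.convert_to_tensor(new[] { 2, 3 });
var z = array_ops.zeros(new Tensors(dyn_shape), tf.float32);   // a (2, 3) float32 tensor of zeros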
| public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) | |||
| { | |||
| return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate | |||
| @@ -307,6 +328,9 @@ namespace Tensorflow | |||
| public static Tensor fill<T>(Shape dims, T value, string name = null) | |||
| => gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); | |||
| public static Tensor fill<T>(Tensor dims, T value, string name = null) | |||
| => gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); | |||
| /// <summary> | |||
| /// Returns the rank of a tensor. | |||
| /// </summary> | |||
| @@ -947,38 +971,70 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
| public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1, | |||
| string name = "split") | |||
| /// <summary> | |||
| /// Transposes last two dimensions of tensor `a`. | |||
| /// For example: | |||
| /// <code> python | |||
| /// x = tf.constant([[1, 2, 3], [4, 5, 6]]) | |||
| /// tf.matrix_transpose(x) # [[1, 4], | |||
| /// # [2, 5], | |||
| /// # [3, 6]] | |||
| /// </code> | |||
| /// For a tensor with two batch dimensions, e.g. x.shape = [1, 2, 3, 4], | |||
| /// tf.linalg.matrix_transpose(x) has shape [1, 2, 4, 3]. | |||
| /// </summary> | |||
| /// <param name="a"></param> | |||
| /// <param name="name"></param> | |||
| /// <param name="conjugate"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="ValueError"></exception> | |||
| public static Tensor matrix_transpose(Tensor a, string name = "matrix_transpose", bool conjugate = false) | |||
| { | |||
| if (num == -1) | |||
| num = (int)size_splits.shape[0]; | |||
| return gen_array_ops.split_v(value, size_splits, tf.convert_to_tensor(axis), num, name: name); | |||
| return tf_with(ops.name_scope(name, "transpose", new { a }), scope => | |||
| { | |||
| var a_shape = a.shape; | |||
| var ndims = a.shape.ndim; | |||
| Axis perm; | |||
| if (ndims != -1) | |||
| { | |||
| if (ndims < 2) | |||
| { | |||
| throw new ValueError("Argument `a` should be a (batch) matrix with rank " + | |||
| $">= 2. Received `a` = {a} with shape: {a_shape}"); | |||
| } | |||
| perm = new Axis(Enumerable.Range(0, ndims - 2).Concat(new int[] { ndims - 1, ndims - 2 }).ToArray()); | |||
| } | |||
| else | |||
| { | |||
| var a_rank = a.rank; | |||
| perm = new Axis(Enumerable.Range(0, a_rank - 2).Concat(new int[] { a_rank - 1, a_rank - 2 }).ToArray()); | |||
| } | |||
| return transpose(a, perm:perm, conjugate:conjugate); | |||
| }); | |||
| } | |||
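A quick sketch of the helper on a concrete 2-D tensor; for higher-rank inputs only the last two axes are swapped, as the XML doc above describes. array_ops is the library-internal entry point used by the Orthogonal initializer change earlier:

using Tensorflow;
using static Tensorflow.Binding;

var x = tf.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } });   // shape (2, 3)
var y = array_ops.matrix_transpose(x);                      // shape (3, 2): [[1, 4], [2, 5], [3, 6]]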
| public static Tensor[] split<T>(Tensor value, int num_split, T axis, | |||
| public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis = null, | |||
| string name = "split") | |||
| { | |||
| var size_splits = ops.convert_to_tensor(num_split); | |||
| return gen_array_ops.split(split_dim: axis, value: value, num_split: num_or_size_splits, name); | |||
| } | |||
| if (tf.Context.executing_eagerly()) | |||
| public static Tensor[] split(Tensor value, int[] num_or_size_splits, Tensor axis = null, int num = -1, | |||
| string name = "split") | |||
| { | |||
| if(num_or_size_splits.Length == 0) | |||
| { | |||
| return split_eager_fallback(axis, value, num_split: num_split, name: name, ctx: tf.Context); | |||
| throw new ValueError("Rank-0 tensors are not supported as the num_or_size_splits argument to split."); | |||
| } | |||
| var size_splits = ops.convert_to_tensor(num_or_size_splits); | |||
| var _op = tf.OpDefLib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); | |||
| return _op.outputs; | |||
| } | |||
| private static Tensor[] split_eager_fallback<Ta, Tv>(Ta axis, Tv value, int num_split, string name, Context ctx = null) | |||
| { | |||
| var (_attr_T, input) = tf.Runner.ArgsToMatchingEager(ctx, args: new object[] { value }); | |||
| var axis_tensor = ops.convert_to_tensor(axis, dtype: TF_DataType.TF_INT32); | |||
| var _inputs_flat = new List<Tensor> { axis_tensor }; | |||
| _inputs_flat.AddRange(input); | |||
| var _attrs = new object[] { "num_split", num_split, "T", _attr_T }; | |||
| if(num == -1) | |||
| { | |||
| num = (int)size_splits.shape[0]; | |||
| } | |||
| return tf.Runner.Execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name); | |||
| return gen_array_ops.split_v(value: value, size_splits: size_splits, split_dim: axis, num_split: num, name: name); | |||
| } | |||
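The renamed num_or_size_splits parameter mirrors the call-site updates elsewhere in this change: an int requests an even split, while an int[] gives explicit sizes and is lowered to SplitV. A sketch with illustrative sizes:

using Tensorflow;
using static Tensorflow.Binding;

var gate_inputs = tf.ones(new Shape(8, 64));   // e.g. (batch, 4 * num_units)
// Even split into 4 chunks along axis 1, as BasicLstmCell does above:
var gates = array_ops.split(gate_inputs, num_or_size_splits: 4, axis: ops.convert_to_tensor(1));
// Uneven split with explicit sizes:
var parts = array_ops.split(gate_inputs, new[] { 16, 48 }, axis: ops.convert_to_tensor(1));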
| public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null) | |||
| @@ -675,16 +675,17 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public static Tensor[] while_loop(Func<Tensor[], Tensor> cond, | |||
| Func<Tensor[], Tensor[]> body, | |||
| Tensor[] loop_vars, | |||
| public static Tensors while_loop(Func<Tensors, Tensor> cond, | |||
| Func<Tensors, Tensors> body, | |||
| Tensors loop_vars, | |||
| int parallel_iterations = 10, | |||
| string name = null) | |||
| { | |||
| var executing_eagerly = tf.Context.executing_eagerly(); | |||
| if (!executing_eagerly) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| return while_v2.while_loop(cond, body, loop_vars, parallel_iterations: parallel_iterations, | |||
| name: name); | |||
| } | |||
| return tf_with(ops.name_scope("name", "while"), delegate | |||
| @@ -16,12 +16,20 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Operations; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public class control_flow_util | |||
| { | |||
| public static readonly bool ENABLE_CONTROL_FLOW_V2 = (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0") || | |||
| (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2") != "0") || | |||
| (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2") != "0") || | |||
| (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2") != "0"); | |||
| /// <summary> | |||
| /// Return true if `op` is an Exit. | |||
| /// </summary> | |||
| @@ -196,5 +204,74 @@ namespace Tensorflow | |||
| } | |||
| return null; | |||
| } | |||
| public static bool EnableControlFlowV2(Graph graph) | |||
| { | |||
| return ENABLE_CONTROL_FLOW_V2 || graph.building_function && (graph is not FuncGraph func || func.captures.Length == 0); | |||
| } | |||
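Because ENABLE_CONTROL_FLOW_V2 is a static readonly field, the environment variables are read exactly once; they must be set before anything in the process touches control_flow_util. A sketch for a console entry point:

using System;

// Run before the first TensorFlow.NET call in the process.
Environment.SetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2", "1");
// The narrower switches are honored as well, e.g.:
// Environment.SetEnvironmentVariable("TF_ENABLE_WHILE_V2", "1");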
| public static string create_new_tf_function(FuncGraph func_graph) | |||
| { | |||
| var func = new EagerDefinedFunction(func_graph.Name, func_graph, func_graph.Inputs, func_graph.Outputs, new Dictionary<string, AttrValue>()); | |||
| func.AddToGraph(func_graph); | |||
| return func_graph.Name; | |||
| } | |||
| public static (Operation, Tensor[]) get_op_and_outputs(Tensor[] inputs) | |||
| { | |||
| if(inputs.Length == 0) | |||
| { | |||
| return (null, new Tensor[0]); | |||
| } | |||
| else | |||
| { | |||
| return (inputs[0], inputs); | |||
| } | |||
| } | |||
| public static Tensor[] run_as_function_for_tape_gradients(Func<Tensor[], Tensor[]> make_op, Tensor[] inputs) | |||
| { | |||
| if(gradients_util.PossibleTapeGradientTypes(inputs) == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER | |||
| && !(ops.get_default_graph().building_function)) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| else | |||
| { | |||
| return make_op(inputs); | |||
| } | |||
| } | |||
| public static string unique_fn_name(string scope, string name) | |||
| { | |||
| return $"{scope}{name}_{ops.uid()}".Replace("/", "_"); | |||
| } | |||
| public static bool output_all_intermediates() | |||
| { | |||
| if (in_defun()) | |||
| { | |||
| return false; | |||
| } | |||
| if(tf.Context.FunctionCallOptions.ExecutorType == "SINGLE_THREADED_EXECUTOR") | |||
| { | |||
| return false; | |||
| } | |||
| // TODO(Rinne): check this after refactoring keras building. | |||
| return false; | |||
| } | |||
| public static bool in_defun() | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| return false; | |||
| } | |||
| var graph = ops.get_default_graph(); | |||
| // TODO(Rinne): CondBranchFuncGraph, WhileBodyFuncGraph, WhileCondFuncGraph | |||
| return graph is FuncGraph; | |||
| } | |||
| } | |||
| } | |||
| @@ -1778,10 +1778,10 @@ new_height, new_width"); | |||
| { | |||
| // a_y_min: [0], a_x_min: [1], a_y_max: [2], a_x_max[3] | |||
| var a_xy_minmax = array_ops.split( | |||
| value: boxes_a, num_split: 4, axis: 2); | |||
| value: boxes_a, num_or_size_splits: 4, axis: ops.convert_to_tensor(2)); | |||
| // b_y_min: [0], b_x_min: [1], b_y_max: [2], b_x_max[3] | |||
| var b_xy_minmax = array_ops.split( | |||
| value: boxes_b, num_split: 4, axis: 2); | |||
| value: boxes_b, num_or_size_splits: 4, axis: ops.convert_to_tensor(2)); | |||
| var i_xmin = math_ops.maximum( | |||
| a_xy_minmax[1], array_ops.transpose(b_xy_minmax[1], new[] { 0, 2, 1 })); | |||
| @@ -1943,7 +1943,7 @@ new_height, new_width"); | |||
| using (ops.name_scope("canonicalize_coordinates")) | |||
| { | |||
| // y_1 = [0], x_1 = [1], y_2 = [2], x_2 = [3] | |||
| var yx = array_ops.split(value: boxes, num_split: 4, axis: 2); | |||
| var yx = array_ops.split(value: boxes, num_or_size_splits: 4, axis: ops.convert_to_tensor(2)); | |||
| var y_1_is_min = math_ops.reduce_all( | |||
| gen_math_ops.less_equal(yx[0][0, 0, 0], yx[2][0, 0, 0])); | |||
| var y_minmax = control_flow_ops.cond( | |||
| @@ -0,0 +1,111 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| internal class list_ops | |||
| { | |||
| private static void _set_handle_data(Tensor list_handle, Shape element_shape, TF_DataType element_dtype) | |||
| { | |||
| if(list_handle is EagerTensor eagerTensor) | |||
| { | |||
| var handle_data = new CppShapeInferenceResult.Types.HandleData(); | |||
| handle_data.IsSet = true; | |||
| handle_data.ShapeAndType.Add(new CppShapeInferenceResult.Types.HandleShapeAndType() | |||
| { | |||
| Shape = element_shape.as_proto(), | |||
| Dtype = element_dtype.as_datatype_enum(), | |||
| Type = new FullTypeDef() { TypeId = FullTypeId.TftArray } | |||
| }); | |||
| list_handle.HandleData = handle_data; | |||
| } | |||
| } | |||
| private static Tensor _build_element_shape(Shape? shape) | |||
| { | |||
| if(shape is null || shape.IsNull) | |||
| { | |||
| return ops.convert_to_tensor(-1); | |||
| } | |||
| else | |||
| { | |||
| return ops.convert_to_tensor(shape); | |||
| } | |||
| } | |||
| public static Tensor tensor_list_reserve(Shape? shape, Tensor num_elements, TF_DataType element_dtype, string name = null) | |||
| { | |||
| var result = gen_list_ops.tensor_list_reserve(_build_element_shape(shape), num_elements, element_dtype, name); | |||
| _set_handle_data(result, shape, element_dtype); | |||
| return result; | |||
| } | |||
| public static Tensor tensor_list_from_tensor(Tensor tensor, Shape element_shape, string? name = null) | |||
| { | |||
| var result = gen_list_ops.tensor_list_from_tensor(tensor, _build_element_shape(element_shape), name); | |||
| _set_handle_data(result, tensor.shape, tensor.dtype); | |||
| return result; | |||
| } | |||
| public static Tensor tensor_list_get_item(Tensor input_handle, Tensor index, TF_DataType element_dtype, | |||
| Shape? element_shape = null, string? name = null) | |||
| { | |||
| return gen_list_ops.tensor_list_get_item(input_handle, index, _build_element_shape(element_shape), | |||
| element_dtype, name); | |||
| } | |||
| public static Tensor tensor_list_set_item(Tensor input_handle, Tensor index, Tensor item, | |||
| bool resize_if_index_out_of_bounds = false, string? name = null) | |||
| { | |||
| if (resize_if_index_out_of_bounds) | |||
| { | |||
| var input_list_size = gen_list_ops.tensor_list_length(input_handle); | |||
| input_handle = control_flow_ops.cond(index >= input_list_size, | |||
| () => gen_list_ops.tensor_list_resize(input_handle, index + 1), | |||
| () => input_handle); | |||
| } | |||
| var output_handle = gen_list_ops.tensor_list_set_item(input_handle, index, item, name); | |||
| handle_data_util.copy_handle_data(input_handle, output_handle); | |||
| return output_handle; | |||
| } | |||
| public static Tensor tensor_list_stack(Tensor input_handle, TF_DataType element_dtype, int num_elements = -1, | |||
| Shape? element_shape = null, string? name = null) | |||
| { | |||
| return gen_list_ops.tensor_list_stack(input_handle, _build_element_shape(element_shape), element_dtype, num_elements, name); | |||
| } | |||
| public static Tensor tensor_list_gather(Tensor input_handle, Tensor indices, TF_DataType element_dtype, | |||
| Shape? element_shape = null, string? name = null) | |||
| { | |||
| return gen_list_ops.tensor_list_gather(input_handle, indices, _build_element_shape(element_shape), element_dtype, name); | |||
| } | |||
| public static Tensor tensor_list_scatter(Tensor tensor, Tensor indices, Shape? element_shape = null, Tensor? input_handle = null, | |||
| string? name = null) | |||
| { | |||
| if(input_handle is not null) | |||
| { | |||
| var output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(input_handle, tensor, indices, name); | |||
| handle_data_util.copy_handle_data(input_handle, output_handle); | |||
| return output_handle; | |||
| } | |||
| else | |||
| { | |||
| var output_handle = gen_list_ops.tensor_list_scatter_v2(tensor, indices, _build_element_shape(element_shape), | |||
| constant_op.constant(-1), name); | |||
| _set_handle_data(output_handle, element_shape, tensor.dtype); | |||
| return output_handle; | |||
| } | |||
| } | |||
| public static Tensor empty_tensor_list(Shape? element_shape, TF_DataType element_dtype, int max_num_elements = -1, | |||
| string? name = null) | |||
| { | |||
| return gen_list_ops.empty_tensor_list(_build_element_shape(element_shape), element_dtype: element_dtype, | |||
| max_num_elements: ops.convert_to_tensor(max_num_elements, dtype: dtypes.int32), name: name); | |||
| } | |||
| } | |||
| } | |||
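list_ops is the thin wrapper over the TensorList kernels that backs _GraphTensorArrayV2 above. The class is internal, so the sketch below illustrates how the TensorArray code drives it rather than a public API; the reserve/set/stack round trip assumes a fully defined element shape so the untouched slot can be zero-filled by the kernel:

using Tensorflow;
using Tensorflow.Operations;
using static Tensorflow.Binding;

// Reserve a variant-typed list of two float32 elements of shape (3).
var handle = list_ops.tensor_list_reserve(new Shape(3), ops.convert_to_tensor(2), tf.float32);
// Write element 0, then stack the list back into a (2, 3) tensor.
handle = list_ops.tensor_list_set_item(handle, ops.convert_to_tensor(0), tf.ones(new Shape(3)));
var stacked = list_ops.tensor_list_stack(handle, tf.float32, element_shape: new Shape(3));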
| @@ -30,7 +30,7 @@ namespace Tensorflow | |||
| name: name); | |||
| return tf.Context.ExecuteOp("PrintV2", name, new ExecuteOpArgs(formatted_string) | |||
| .SetAttributes(new { output_stream, end })); | |||
| .SetAttributes(new { output_stream, end })).SingleOrNull; | |||
| } | |||
| } | |||
| } | |||
| @@ -44,7 +44,7 @@ namespace Tensorflow | |||
| { | |||
| sorted = true | |||
| })); | |||
| return indices; | |||
| return indices.Single; | |||
| } | |||
| public static Tensor sort(Tensor values, Axis axis, string direction = "ASCENDING", string? name = null) | |||
| @@ -13,11 +13,23 @@ namespace Tensorflow | |||
| /// <returns></returns> | |||
| public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow) | |||
| { | |||
| var new_ta = tf.TensorArray( | |||
| dtype: old_ta.dtype, | |||
| infer_shape: old_ta.infer_shape, | |||
| if (!tf.Context.executing_eagerly() && old_ta is not _GraphTensorArrayV2 && control_flow_util.EnableControlFlowV2(ops.get_default_graph())) | |||
| { | |||
| throw new NotImplementedException("Attempting to build a graph-mode TF2-style " | |||
| + "TensorArray from either an eager-mode " | |||
| + "TensorArray or a TF1-style TensorArray. " | |||
| + "This is not currently supported. You may be " | |||
| + "attempting to capture a TensorArray " | |||
| + "inside a tf.function or tf.data map function. " | |||
| + "Instead, construct a new TensorArray inside " | |||
| + "the function."); | |||
| } | |||
| var new_ta = TensorArray.Create(old_ta.dtype, handle: old_ta.handle, flow: flow, infer_shape: old_ta.infer_shape, | |||
| colocate_with_first_write_call: old_ta.colocate_with_first_write_call); | |||
| new_ta._dynamic_size = old_ta._dynamic_size; | |||
| new_ta._size = old_ta._size; | |||
| new_ta._colocate_with = old_ta._colocate_with; | |||
| new_ta._element_shape = old_ta._element_shape; | |||
| return new_ta; | |||
| } | |||
| @@ -0,0 +1,401 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| using Tensorflow.Common.Extensions; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Graphs; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| class _OperationWithOutputs : Operation | |||
| { | |||
| public _OperationWithOutputs(IntPtr handle, Graph g = null) | |||
| { | |||
| _handle = handle; | |||
| _graph = g; | |||
| _outputs = null; | |||
| g._add_op(this); | |||
| } | |||
| } | |||
| internal class while_v2 | |||
| { | |||
| public static Tensor[] while_loop(Func<Tensors, Tensor> cond, | |||
| Func<Tensors, Tensors> body, | |||
| Tensors loop_vars, | |||
| int maximum_iterations = -1, | |||
| int parallel_iterations = 10, | |||
| string name = null, | |||
| bool back_prop = true, | |||
| bool return_same_structure = true) | |||
| { | |||
| var orig_loop_vars = loop_vars; | |||
| var flat_orig_loop_vars = orig_loop_vars.Flatten().ToArray(); | |||
| int len_orig_loop_vars = orig_loop_vars.Length; | |||
| loop_vars = _tensor_array_to_flow(loop_vars); | |||
| loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors(); | |||
| var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars)); | |||
| var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray(); | |||
| if(string.IsNullOrEmpty(name)) | |||
| { | |||
| name = "while"; | |||
| } | |||
| return tf_with<ITensorFlowObject, Tensor[]>(ops.name_scope(name), nameScopeWhile => | |||
| { | |||
| string scope = (nameScopeWhile as ops.NameScope).scope_name; | |||
| string cond_name = control_flow_util.unique_fn_name(scope, "cond"); | |||
| string body_name = control_flow_util.unique_fn_name(scope, "body"); | |||
| var maximum_iterations_loop_var = _build_maximum_iterations_loop_var(maximum_iterations); | |||
| var loop_counter = constant_op.constant(0, maximum_iterations == -1 ? TF_DataType.DtInvalid : maximum_iterations_loop_var.dtype, | |||
| name: "loop_counter"); | |||
| loop_vars = new Tensor[] { loop_counter, maximum_iterations_loop_var }.Concat(loop_vars).ToArray(); | |||
| var func_graph_signature = new TensorSpec[] {TensorSpec.FromTensor(loop_counter),TensorSpec.FromTensor(maximum_iterations_loop_var)} | |||
| .Concat(loop_vars_signature.Flatten()).ToArray(); | |||
| // TODO(Rinne): possible wrong implementation here. | |||
| var add_control_dependencies = false; | |||
| object[] wrapped_cond(object[] inputs) | |||
| { | |||
| Tensor loop_counter = (Tensor)inputs[0]; | |||
| Tensor maximum_iterations_arg = (Tensor)inputs[1]; | |||
| Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray(); | |||
| var pred = cond(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args)); | |||
| if(pred.shape.IsNull || pred.shape.ndim > 0) | |||
| { | |||
| pred = array_ops.squeeze(pred); | |||
| } | |||
| if(maximum_iterations == -1) | |||
| { | |||
| return new object[] { pred }; | |||
| } | |||
| else | |||
| { | |||
| return new object[] { math_ops.logical_and(loop_counter < maximum_iterations_arg, pred) }; | |||
| } | |||
| } | |||
| var cond_graph = FuncGraph.func_graph_from_func(cond_name, wrapped_cond, null, | |||
| null, signature: func_graph_signature, add_control_dependencies: add_control_dependencies); | |||
| bool stateful_parallelism = false; | |||
| object[] wrapped_body(object[] inputs) | |||
| { | |||
| Tensor loop_counter = (Tensor)inputs[0]; | |||
| Tensor maximum_iterations_arg = (Tensor)inputs[1]; | |||
| Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray(); | |||
| _copy_handle_data(loop_vars.Flatten().Skip(2), args); | |||
| foreach(var t in cond_graph.external_captures) | |||
| { | |||
| var graph = (FuncGraph)(ops.get_default_graph()); | |||
| graph.capture(t); | |||
| } | |||
| var outputs = body(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args)); | |||
| outputs = _tensor_array_to_flow(outputs); | |||
| return new object[] { loop_counter + 1, maximum_iterations_arg }.Concat(outputs).ToArray(); | |||
| } | |||
| var body_graph = FuncGraph.func_graph_from_func(body_name, wrapped_body, null, null, func_graph_signature, | |||
| add_control_dependencies: add_control_dependencies, acd_record_initial_resource_uses: stateful_parallelism); | |||
| // TODO(Rinne): possible wrong implementation here. | |||
| NestList<Tensors> loop_vars_list = new(new Tensors[] { loop_vars, body_graph.external_captures.ToTensors() }); | |||
| body_graph.Outputs.AddRange(body_graph.internal_captures); | |||
| cond_graph.as_default(); | |||
| int num_cond_captures = cond_graph.external_captures.Length; | |||
| Debug.Assert(cond_graph.external_captures.SequenceEqual(body_graph.external_captures.Take(num_cond_captures).ToArray())); | |||
| _duplicate_body_captures_in_cond(cond_graph, body_graph.external_captures.Skip(num_cond_captures).ToArray()); | |||
| cond_graph.Exit(); | |||
| int first_loop_var_index = 2; | |||
| int num_flattened_outputs = orig_loop_vars.Length; | |||
| int num_original_outputs = body_graph.Outputs.Length; | |||
| if (back_prop && control_flow_util.output_all_intermediates()) | |||
| { | |||
| var intermediate_tensors = _get_intermediates(body_graph); | |||
| foreach(var intermediate_tensor in intermediate_tensors) | |||
| { | |||
| var tensor_list = list_ops.empty_tensor_list(intermediate_tensor.shape, intermediate_tensor.dtype, maximum_iterations); | |||
| loop_vars_list.Values.Add(tensor_list); | |||
| cond_graph.as_default(); | |||
| cond_graph.capture(tensor_list); | |||
| cond_graph.Exit(); | |||
| body_graph.as_default(); | |||
| var appended_tensor_list = gen_ops.tensor_list_push_back(tensor_list, intermediate_tensor); | |||
| body_graph.Outputs.Add(appended_tensor_list); | |||
| body_graph.Exit(); | |||
| } | |||
| } | |||
| List<Tensor> flattened_loop_vars = new(); | |||
| foreach(var item in loop_vars_list.Values) | |||
| { | |||
| flattened_loop_vars.AddRange(item.Flatten()); | |||
| } | |||
| // skip the check | |||
| // TODO(Rinne): deal with control dependencies | |||
| var output_shapes = body_graph.Outputs.Select(t => t.shape).ToArray(); | |||
| var span = new Span<Shape>(output_shapes).Slice(first_loop_var_index, num_flattened_outputs); | |||
| for(int i = 0; i < span.Length; i++) | |||
| { | |||
| span[i] = flat_shape_invariants[i]; | |||
| } | |||
| Tensor[] outputs = _build_while_op(flattened_loop_vars.ToArray(), cond_graph, body_graph, output_shapes, parallel_iterations, | |||
| (nameScopeWhile as ops.NameScope).scope_name, num_original_outputs, stateful_parallelism); | |||
| if (!ops.get_default_graph().building_function) | |||
| { | |||
| outputs = outputs.Select(t => array_ops.identity(t)).ToArray(); | |||
| } | |||
| var output_loop_vars = outputs.Skip(first_loop_var_index).Take(num_flattened_outputs).ToArray(); | |||
| if (!back_prop) | |||
| { | |||
| output_loop_vars = output_loop_vars.Select(t => array_ops.stop_gradient(t)).ToArray(); | |||
| } | |||
| outputs = _pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, output_loop_vars); | |||
| return outputs; | |||
| }); | |||
| } | |||
| private static Tensors _tensor_array_to_flow(Tensors loop_vars) | |||
| { | |||
| if(loop_vars.NestType == NestType.Node) | |||
| { | |||
| if(loop_vars.NodeValue is FakeTensorByTensorArray fake) | |||
| { | |||
| return new Tensors(fake.TensorArray.flow); | |||
| } | |||
| else | |||
| { | |||
| return new Tensors(loop_vars.NodeValue!); | |||
| } | |||
| } | |||
| else if(loop_vars.NestType == NestType.List) | |||
| { | |||
| List<INestStructure<Tensor>> list = new(); | |||
| foreach(var item in loop_vars.ListValue!) | |||
| { | |||
| if(item.NestType == NestType.Node) | |||
| { | |||
| var nested = item.AsNest(); | |||
| if (nested.NodeValue is FakeTensorByTensorArray fake) | |||
| { | |||
| list.Add(new Nest<Tensor>(fake.TensorArray.flow)); | |||
| } | |||
| else | |||
| { | |||
| list.Add(new Nest<Tensor>(nested.NodeValue!)); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| list.Add(new Nest<Tensor>(item.AsNest())); | |||
| } | |||
| } | |||
| return Tensors.FromNest(new Nest<Tensor>(list)); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| private static Tensor[] _build_while_op(Tensor[] loop_vars, FuncGraph cond_graph, FuncGraph body_graph, | |||
| Shape[] output_shapes, int parallel_iterations, string name, int num_original_outputs, bool stateful_parallelism) | |||
| { | |||
| var cond_stateful_ops = cond_graph.get_operations().Select(x => x.op); | |||
| var body_stateful_ops = body_graph.get_operations().Select(x => x.op); | |||
| bool is_stateful = cond_stateful_ops.Count() > 0 || body_stateful_ops.Count() > 0; | |||
| Tensor[] _make_op(Tensor[] inputs) | |||
| { | |||
| Tensor[] outputs; | |||
| if (is_stateful) | |||
| { | |||
| outputs = gen_functional_ops._while( | |||
| inputs, | |||
| control_flow_util.create_new_tf_function(cond_graph), | |||
| control_flow_util.create_new_tf_function(body_graph), | |||
| output_shapes, | |||
| parallel_iterations, | |||
| name | |||
| ); | |||
| } | |||
| else | |||
| { | |||
| outputs = gen_functional_ops.stateless_while( | |||
| inputs, | |||
| control_flow_util.create_new_tf_function(cond_graph), | |||
| control_flow_util.create_new_tf_function(body_graph), | |||
| output_shapes, | |||
| parallel_iterations, | |||
| name | |||
| ); | |||
| } | |||
| var (while_op, tensors) = control_flow_util.get_op_and_outputs(outputs); | |||
| _copy_handle_data(body_graph.Outputs, tensors); | |||
| _set_read_only_resource_inputs_attr(while_op, new FuncGraph[]{cond_graph, body_graph}); | |||
| while_op._set_attr("_num_original_outputs", new AttrValue() { I = num_original_outputs }); | |||
| while_op._set_attr("_stateful_parallelism", new AttrValue() { B = stateful_parallelism }); | |||
| cond_graph.outer_graph = ops.get_default_graph(); | |||
| body_graph.outer_graph = ops.get_default_graph(); | |||
| // TODO(Rinne): set the two graphs to while_op | |||
| return tensors; | |||
| } | |||
| return control_flow_util.run_as_function_for_tape_gradients(_make_op, loop_vars); | |||
| } | |||
| /// <summary> | |||
| /// Sets the list of resource inputs which are read-only. This is used by AutomaticControlDependencies. | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| /// <param name="branch_graphs"></param> | |||
| private static void _set_read_only_resource_inputs_attr(Operation op, FuncGraph[] branch_graphs) | |||
| { | |||
| List<int> read_only_indices = Enumerable.Range(0, op.inputs.Length).ToList(); | |||
| foreach(var branch_graph in branch_graphs) | |||
| { | |||
| if (read_only_indices.Count == 0) | |||
| { | |||
| break; | |||
| } | |||
| var branch_read_only_indices = auto_control_deps_utils.get_read_only_resource_input_indices_graph(branch_graph); | |||
| read_only_indices = read_only_indices.Intersect(branch_read_only_indices).ToList(); | |||
| } | |||
| AttrValue.Types.ListValue listValue = new(); | |||
| listValue.I.AddRange(read_only_indices.OrderBy(x => x).Select(x => (long)x)); | |||
| op._set_attr(auto_control_deps_utils.READ_ONLY_RESOURCE_INPUTS_ATTR, new AttrValue() | |||
| { | |||
| List = listValue | |||
| }); | |||
| } | |||
| private static Tensors _pack_sequence_as<T>(INestStructure<T> loop_vars_signature, Tensor[] flat_orig_loop_vars, Tensor[] loop_vars) | |||
| { | |||
| var flattened_loop_vars = zip(loop_vars, flat_orig_loop_vars).Select<(Tensor, Tensor), Tensor>(item => | |||
| { | |||
| var (flow, y) = item; | |||
| if (y is FakeTensorByTensorArray ta) | |||
| { | |||
| return new FakeTensorByTensorArray(tensor_array_ops.build_ta_with_new_flow(ta.TensorArray, flow)); | |||
| } | |||
| else | |||
| { | |||
| return flow; | |||
| } | |||
| }).ToArray(); | |||
| return Nest.PackSequenceAs(loop_vars_signature, flattened_loop_vars).ToTensors(); | |||
| } | |||
| private static Tensor[] _get_intermediates(FuncGraph func_graph) | |||
| { | |||
| List<Tensor> intermediates = new(); | |||
| var reversed_captures = func_graph.captures.ToDictionary(x => x.Item2, x => x.Item1); | |||
| foreach(var op in func_graph.get_operations()) | |||
| { | |||
| Debug.Assert(op is Operation); | |||
| var oper = (Operation)op; | |||
| if(oper.type == "Identity" || oper.type == "MutexLock") | |||
| { | |||
| continue; | |||
| } | |||
| foreach(var o in op.outputs) | |||
| { | |||
| if(o != func_graph.Inputs[0] && o.dtype != dtypes.resource && !reversed_captures.ContainsKey(o)) | |||
| { | |||
| intermediates.Add(o); | |||
| } | |||
| } | |||
| } | |||
| return intermediates.ToArray(); | |||
| } | |||
| private static void _duplicate_body_captures_in_cond(FuncGraph cond_graph, Tensor[] body_graph_captures) | |||
| { | |||
| var types = body_graph_captures.Select(t => t.dtype).ToList(); | |||
| var c_graph = cond_graph.c_graph; | |||
| var placeholders = types.Select(x => CreatePlaceholder(c_graph, _build_cond_placeholders_name_prefix(cond_graph), x)).ToList(); | |||
| var placeholder_ops = placeholders.Select(ph => new _OperationWithOutputs(ph.oper, cond_graph)).ToList(); | |||
| List<Tensor> tensors = new(); | |||
| foreach(var (op, ph, dtype) in zip(placeholder_ops, placeholders, types)) | |||
| { | |||
| var tensor = Tensor._create_with_tf_output(op, 0, dtype, ph); | |||
| op._outputs = new Tensor[] { tensor }; | |||
| tensors.Add(tensor); | |||
| } | |||
| var tuples = zip(body_graph_captures, tensors).ToList(); | |||
| var keys = body_graph_captures.Select(t => t.Id).ToList(); | |||
| cond_graph._captures.Update(zip(keys, tuples).ToDictionary(x => x.Item1, x => x.Item2)); | |||
| cond_graph.Inputs.AddRange(tensors); | |||
| } | |||
| private static TF_Output CreatePlaceholder(SafeGraphHandle graph, string name, TF_DataType dtype) | |||
| { | |||
| var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||
| c_api.TF_SetAttrType(desc, "dtype", dtype); | |||
| var op = c_api.TF_FinishOperation(desc, tf.Status); | |||
| tf.Status.Check(true); | |||
| var output = new TF_Output(); | |||
| output.oper = op; | |||
| output.index = 0; | |||
| return output; | |||
| } | |||
| private static string _build_cond_placeholders_name_prefix(FuncGraph cond_graph) | |||
| { | |||
| return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder"); | |||
| } | |||
| private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype, | |||
| string name) | |||
| { | |||
| return ops.convert_to_tensor(value, dtype, name, false); | |||
| } | |||
| private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1) | |||
| { | |||
| return ops.convert_to_tensor(maximum_iterations, dtypes.int32, "maximum_iterations"); | |||
| } | |||
| private static void _copy_handle_data(IEnumerable<Tensor> src_tensors, IEnumerable<Tensor> dst_tensors) | |||
| { | |||
| foreach(var (src_t, dst_t) in zip(src_tensors, dst_tensors)) | |||
| { | |||
| handle_data_util.copy_handle_data(src_t, dst_t); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Diagnostics; | |||
| using System.Runtime.CompilerServices; | |||
| using Tensorflow.Exceptions; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.c_api; | |||
| @@ -88,7 +89,7 @@ namespace Tensorflow | |||
| case TF_Code.TF_INVALID_ARGUMENT: | |||
| throw new InvalidArgumentError(message); | |||
| default: | |||
| throw new TensorflowException(message); | |||
| throw new NotOkStatusException(message); | |||
| } | |||
| } | |||
| } | |||
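A hedged sketch of what this change means for callers (the try body is illustrative; only NotOkStatusException and tf.Status.Check are taken from this PR): unrecognised non-OK status codes now surface as NotOkStatusException rather than the generic TensorflowException.

using System;
using Tensorflow.Exceptions;
using static Tensorflow.Binding;

try
{
    // Any native call path that ends in a status check can land here.
    tf.Status.Check(true);
}
catch (NotOkStatusException ex)
{
    // Previously this would have been a plain TensorflowException.
    Console.WriteLine(ex.Message);
}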
| @@ -111,7 +111,12 @@ https://tensorflownet.readthedocs.io</Description> | |||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | |||
| <PackageReference Include="Newtonsoft.Json" Version="13.0.3" /> | |||
| <PackageReference Include="OneOf" Version="3.0.223" /> | |||
| <PackageReference Include="Protobuf.Text" Version="0.7.0" /> | |||
| <PackageReference Include="Protobuf.Text" Version="0.7.1" /> | |||
| <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | |||
| </ItemGroup> | |||
| <ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'"> | |||
| <PackageReference Include="IsExternalInit" Version="1.0.3" PrivateAssets="all" /> | |||
| <PackageReference Include="System.Memory" Version="4.5.4" PrivateAssets="all" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -105,6 +105,13 @@ namespace Tensorflow | |||
| _id = ops.uid(); | |||
| } | |||
| internal static Tensor _create_with_tf_output(Operation op, int value_index, TF_DataType dtype, TF_Output tf_output) | |||
| { | |||
| Tensor ret = new Tensor(op, value_index, dtype); | |||
| ret._tf_output = tf_output; | |||
| return ret; | |||
| } | |||
| protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | |||
| { | |||
| _handle = TF_NewTensor(shape, dtype, null); | |||
| @@ -14,7 +14,9 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Operations; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -44,5 +46,27 @@ namespace Tensorflow | |||
| public abstract Tensor stack(string name = null); | |||
| public abstract Tensor gather(Tensor indices, string name = null); | |||
| internal bool _dynamic_size; | |||
| internal Tensor _size; | |||
| internal List<Tensor> _colocate_with; | |||
| internal Shape _element_shape; | |||
| public static TensorArray Create(TF_DataType dtype, Tensor size = null, bool dynamic_size = false, | |||
| bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
| bool infer_shape = true, Shape? element_shape = null, | |||
| bool colocate_with_first_write_call = true, string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly() && (flow is null || flow.dtype != dtypes.variant)) | |||
| { | |||
| return new _EagerTensorArray(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow, | |||
| infer_shape, element_shape, colocate_with_first_write_call, name); | |||
| } | |||
| else | |||
| { | |||
| return new _GraphTensorArrayV2(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow, | |||
| infer_shape, element_shape, colocate_with_first_write_call, name); | |||
| } | |||
| } | |||
| } | |||
| } | |||
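A hedged usage sketch of the new factory (the values are made up; unstack, read and stack are the members the RNN code later in this PR relies on): in eager mode, or when flow is not a variant tensor, the eager-backed array is returned, otherwise the graph (v2) implementation.

// Sketch only; assumes the tf.constant helpers used elsewhere in this PR.
var ta = TensorArray.Create(dtype: TF_DataType.TF_FLOAT, size: tf.constant(3));

// One element per "time step", then read a single element and stack everything back.
ta = ta.unstack(tf.constant(new float[] { 1f, 2f, 3f }));
var second = ta.read(tf.constant(1));   // tensor holding 2f
var stacked = ta.stack();               // single tensor containing the three elements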
| @@ -3,6 +3,9 @@ using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Common.Extensions; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -13,157 +16,278 @@ namespace Tensorflow | |||
| /// and Tensor[] from Tensors implicitly. | |||
| /// It works for tuple and scalar as well. | |||
| /// </summary> | |||
| public class Tensors : IEnumerable<Tensor>, IDisposable | |||
| public sealed class Tensors : Nest<Tensor>, IDisposable | |||
| { | |||
| List<Tensor> items = new List<Tensor>(); | |||
| public TF_DataType dtype => items.First().dtype; | |||
| public Shape shape => items.First().shape; | |||
| public int rank => items.First().rank; | |||
| public Graph graph => items.First().graph; | |||
| public TF_DataType dtype => this.First().dtype; | |||
| public Shape shape => this.First().shape; | |||
| public int rank => this.First().rank; | |||
| public Graph graph => this.First().graph; | |||
| public bool IsList { get; set; } | |||
| public int Length => items.Count(); | |||
| public int Length => this.Count(); | |||
| /// <summary> | |||
| /// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception. | |||
| /// </summary> | |||
| public Tensor Single | |||
| { | |||
| get | |||
| { | |||
| if (Length != 1) | |||
| { | |||
| throw new ValueError("Tensors with more than one tensor cannot be " + | |||
| "implicitly converted to Tensor."); | |||
| } | |||
| return this.First(); | |||
| } | |||
| } | |||
| public Tensor this[int index] | |||
| /// <summary> | |||
| /// Return a Tensor if `Tensors` has only one tensor, and return null when `Tensors` is empty, | |||
| /// otherwise throw an exception. | |||
| /// </summary> | |||
| public Tensor? SingleOrNull | |||
| { | |||
| get => items[index]; | |||
| set => items[index] = value; | |||
| get | |||
| { | |||
| if (Length > 1) | |||
| { | |||
| throw new ValueError($"Tensors with {Length} tensor cannot be " + | |||
| "implicitly converted to Tensor."); | |||
| } | |||
| return this.FirstOrDefault(); | |||
| } | |||
| } | |||
| public Tensor this[params string[] slices] | |||
| => items.First()[slices]; | |||
| public Tensors(params Tensor[] tensors) | |||
| => this.First()[slices]; | |||
| internal Tensors(Nest<Tensor> nested) : base(nested) | |||
| { | |||
| items.AddRange(tensors); | |||
| } | |||
| public Tensors(IEnumerable<Tensor> tensors) | |||
| public Tensors(params Tensor[] tensors): base(DealWithConstructorArrayInput(tensors)) | |||
| { | |||
| items.AddRange(tensors); | |||
| } | |||
| public Tensors(NDArray nd) | |||
| public Tensors(IList<Tensor> tensors) : base(tensors.Select(x => new Nest<Tensor>(x))) | |||
| { | |||
| items.Add(ops.convert_to_tensor(nd)); | |||
| } | |||
| public IEnumerator<Tensor> GetEnumerator() | |||
| public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) | |||
| { | |||
| foreach (var tensor in items) | |||
| yield return tensor; | |||
| } | |||
| /// <summary> | |||
| /// Get the element at the shallow level. For example, for ts = [1, [2, 3], 4], | |||
| /// the common (flattened) indexer gives ts[1] = 2, while the shallow indexer gives ts[1] = [2, 3]. | |||
| /// </summary> | |||
| /// <param name="index"></param> | |||
| /// <returns></returns> | |||
| public Tensors GetShallow(int index) | |||
| { | |||
| if(NestType == NestType.Node) | |||
| { | |||
| if(index > 0) | |||
| { | |||
| throw new IndexOutOfRangeException(); | |||
| } | |||
| return this; | |||
| } | |||
| else if(NestType == NestType.List) | |||
| { | |||
| return ListValue![index].AsNest().ToTensors(); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors) | |||
| { | |||
| if (tensors.Length == 0) | |||
| { | |||
| return Nest<Tensor>.Empty; | |||
| } | |||
| else if(tensors.Length == 1) | |||
| { | |||
| return new Nest<Tensor>(tensors[0]); | |||
| } | |||
| else | |||
| { | |||
| return new Nest<Tensor>(tensors.Select(x => new Nest<Tensor>(x))); | |||
| } | |||
| } | |||
| public bool IsSingle() | |||
| { | |||
| return Length == 1; | |||
| } | |||
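A hedged sketch of how the reworked class behaves, based on the constructors and GetShallow above (the tensor values are illustrative): one tensor yields a Node nest, several tensors yield a List nest, and flat enumeration over the leaves keeps working.

var a = tf.constant(1);
var b = tf.constant(2);

var single = new Tensors(a);      // Node nest: single.IsSingle() is true, single.Single is a
var pair = new Tensors(a, b);     // List nest with two leaves: pair.Length is 2

Tensor first = pair.First();      // flat enumeration over the leaves
Tensor asTensor = single;         // implicit conversion goes through SingleOrNull, so this is a
var shallowSecond = pair.GetShallow(1);   // Tensors wrapping b, per GetShallow above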
| public new Tensors MergeWith(Nest<Tensor>? other) | |||
| { | |||
| return FromNest(base.MergeWith(other)); | |||
| } | |||
| [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + | |||
| "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] | |||
| public void Add(Tensor tensor) | |||
| => items.Add(tensor); | |||
| { | |||
| if(NestType == NestType.Dictionary) | |||
| { | |||
| throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||
| } | |||
| else if(NestType == NestType.Node) | |||
| { | |||
| NestType = NestType.List; | |||
| ListValue = new() { new Nest<Tensor>(NodeValue), new Nest<Tensor>(tensor) }; | |||
| NodeValue = null; | |||
| } | |||
| else if(NestType == NestType.List) | |||
| { | |||
| ListValue!.Add(new Nest<Tensor>(tensor)); | |||
| } | |||
| else //Empty | |||
| { | |||
| NestType = NestType.Node; | |||
| NodeValue = tensor; | |||
| } | |||
| } | |||
| [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + | |||
| "some tensors to `Tensors`, creating a new instance with your newly added tensors is a better choice.")] | |||
| public void AddRange(IEnumerable<Tensor> tensors) | |||
| => items.AddRange(tensors); | |||
| { | |||
| if (NestType == NestType.Dictionary) | |||
| { | |||
| throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||
| } | |||
| else if (NestType == NestType.Node) | |||
| { | |||
| NestType = NestType.List; | |||
| ListValue = new() { new Nest<Tensor>(NodeValue) }; | |||
| ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||
| NodeValue = null; | |||
| } | |||
| else if(NestType == NestType.List) | |||
| { | |||
| ListValue!.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||
| } | |||
| else // empty | |||
| { | |||
| NestType = NestType.List; | |||
| ListValue = tensors.Select(x => new Nest<Tensor>(x) as INestStructure<Tensor>).ToList(); | |||
| } | |||
| } | |||
| [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to insert " + | |||
| "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] | |||
| public void Insert(int index, Tensor tensor) | |||
| => items.Insert(index, tensor); | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| => GetEnumerator(); | |||
| { | |||
| if (NestType == NestType.List) | |||
| { | |||
| ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||
| } | |||
| else if(NestType == NestType.Node) | |||
| { | |||
| NestType = NestType.List; | |||
| ListValue = new() { new Nest<Tensor>(NodeValue) }; | |||
| ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||
| NodeValue = null; | |||
| } | |||
| else | |||
| { | |||
| throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||
| } | |||
| } | |||
| public string[] StringData() | |||
| { | |||
| EnsureSingleTensor(this, "nnumpy"); | |||
| return this[0].StringData(); | |||
| return Single.StringData(); | |||
| } | |||
| public string StringData(int index) | |||
| { | |||
| EnsureSingleTensor(this, "nnumpy"); | |||
| return this[0].StringData(index); | |||
| return Single.StringData(index); | |||
| } | |||
| public NDArray numpy() | |||
| { | |||
| EnsureSingleTensor(this, "nnumpy"); | |||
| return this[0].numpy(); | |||
| return Single.numpy(); | |||
| } | |||
| [Obsolete] | |||
| public T[] ToArray<T>() where T: unmanaged | |||
| { | |||
| EnsureSingleTensor(this, $"ToArray<{typeof(T)}>"); | |||
| return this[0].ToArray<T>(); | |||
| return Single.ToArray<T>(); | |||
| } | |||
| #region Explicit Conversions | |||
| public static explicit operator bool(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to bool"); | |||
| return (bool)tensor[0]; | |||
| return (bool)tensor.Single; | |||
| } | |||
| public static explicit operator sbyte(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to sbyte"); | |||
| return (sbyte)tensor[0]; | |||
| return (sbyte)tensor.Single; | |||
| } | |||
| public static explicit operator byte(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||
| return (byte)tensor[0]; | |||
| return (byte)tensor.Single; | |||
| } | |||
| public static explicit operator ushort(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to ushort"); | |||
| return (ushort)tensor[0]; | |||
| return (ushort)tensor.Single; | |||
| } | |||
| public static explicit operator short(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to short"); | |||
| return (short)tensor[0]; | |||
| return (short)tensor.Single; | |||
| } | |||
| public static explicit operator int(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to int"); | |||
| return (int)tensor[0]; | |||
| return (int)tensor.Single; | |||
| } | |||
| public static explicit operator uint(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to uint"); | |||
| return (uint)tensor[0]; | |||
| return (uint)tensor.Single; | |||
| } | |||
| public static explicit operator long(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to long"); | |||
| return (long)tensor[0]; | |||
| return (long)tensor.Single; | |||
| } | |||
| public static explicit operator ulong(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to ulong"); | |||
| return (ulong)tensor[0]; | |||
| return (ulong)tensor.Single; | |||
| } | |||
| public static explicit operator float(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||
| return (byte)tensor[0]; | |||
| return (float)tensor.Single; | |||
| } | |||
| public static explicit operator double(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to double"); | |||
| return (double)tensor[0]; | |||
| return (double)tensor.Single; | |||
| } | |||
| public static explicit operator string(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to string"); | |||
| return (string)tensor[0]; | |||
| return (string)tensor.Single; | |||
| } | |||
| public static explicit operator object[](Tensors tensors) | |||
| => tensors.items.ToArray(); | |||
| => tensors.Flatten().ToArray(); | |||
| #endregion | |||
| #region Implicit Conversions | |||
| @@ -183,56 +307,44 @@ namespace Tensorflow | |||
| public static implicit operator Tensors(List<Tensor> tensors) | |||
| => new Tensors(tensors.ToArray()); | |||
| public static implicit operator Tensor(Tensors tensors) | |||
| => tensors.FirstOrDefault(); | |||
| public static implicit operator Tensor(Tensors? tensors) | |||
| => tensors?.SingleOrNull; | |||
| public static implicit operator Tensor[](Tensors tensors) | |||
| => tensors.items.ToArray(); | |||
| => tensors.Flatten().ToArray(); | |||
| #endregion | |||
| public void Deconstruct(out Tensor a, out Tensor b) | |||
| public static Tensors? FromNest(Nest<Tensor> nested) | |||
| { | |||
| a = items[0]; | |||
| b = items[1]; | |||
| if(nested == Nest<Tensor>.Empty) | |||
| { | |||
| return null; | |||
| } | |||
| return new Tensors(nested); | |||
| } | |||
| private static void EnsureSingleTensor(Tensors tensors, string methodnName) | |||
| public void Deconstruct(out Tensor a, out Tensors? b) | |||
| { | |||
| if(tensors.Length == 0) | |||
| { | |||
| throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor."); | |||
| } | |||
| else if(tensors.Length > 1) | |||
| { | |||
| throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor."); | |||
| } | |||
| a = this.First(); | |||
| b = Length == 1? null : new Tensors(this.Skip(1).ToArray()); | |||
| } | |||
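For example (a small sketch; a, b and c stand for any existing Tensor values), callers can now peel off the first tensor and keep the rest as a Tensors:

var (head, rest) = new Tensors(a, b, c);   // head is a, rest holds b and c
var (only, none) = new Tensors(a);         // only is a, none is null for a single-tensor Tensors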
| public override string ToString() | |||
| { | |||
| if(items.Count == 1) | |||
| if(Length == 1) | |||
| { | |||
| return items[0].ToString(); | |||
| return this.First().ToString(); | |||
| } | |||
| else | |||
| { | |||
| StringBuilder sb = new StringBuilder(); | |||
| sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n"); | |||
| for(int i = 0; i < items.Count; i++) | |||
| { | |||
| var tensor = items[i]; | |||
| sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n"); | |||
| } | |||
| sb.Append("]\n"); | |||
| return sb.ToString(); | |||
| return $"Totally {Length} tensors: {base.ToString()}"; | |||
| } | |||
| } | |||
| public void Dispose() | |||
| { | |||
| foreach (var item in items) | |||
| item.Dispose(); | |||
| foreach (var tensor in this) | |||
| tensor.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -179,8 +179,7 @@ namespace Tensorflow.Train | |||
| // handles slot variables. | |||
| if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable) | |||
| { | |||
| var temp = new_variable as Trackable; | |||
| var res = _track_trackable(temp, args.Name, args.Overwrite); | |||
| var res = _track_trackable(new_variable as Trackable, args.Name, args.Overwrite); | |||
| Debug.Assert(res is IVariableV1); | |||
| return res as IVariableV1; | |||
| } | |||
| @@ -36,6 +36,7 @@ namespace Tensorflow.Util | |||
| // (np.array([3, 4]), tf.constant([3, 4])))` | |||
| // | |||
| [Obsolete] | |||
| public static class nest | |||
| { | |||
| @@ -170,11 +170,28 @@ namespace Tensorflow | |||
| public Tensor value() | |||
| => GraphElement ?? _read_variable_op(); | |||
| protected Tensor _read_variable_op() | |||
| protected Tensor _read_variable_op(bool no_copy = false) | |||
| { | |||
| variable_accessed(this); | |||
| var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); | |||
| resource_variable_ops._maybe_set_handle_data(_dtype, handle, result); | |||
| Tensor read_and_set_handle(bool no_copy) | |||
| { | |||
| if (no_copy) | |||
| { | |||
| gen_resource_variable_ops.disable_copy_on_read(handle); | |||
| } | |||
| var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); | |||
| resource_variable_ops._maybe_set_handle_data(_dtype, handle, result); | |||
| return result; | |||
| } | |||
| // TODO(Rinne): deal with caching device. | |||
| var result = read_and_set_handle(no_copy); | |||
| if (!tf.Context.executing_eagerly()) | |||
| { | |||
| tf.Runner.TFE_TapeSetRecordOperation("ReadVariableOp", new Tensor[] { result }, new Tensor[] { handle }, | |||
| backward_function: (x, _) => x); | |||
| } | |||
| // have to set shape when converting to substituent placeholder | |||
| if (result.shape.ndim == -1) | |||
| @@ -576,7 +576,7 @@ namespace Tensorflow | |||
| public static HandleData get_resource_handle_data(Tensor graph_op) | |||
| { | |||
| var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | |||
| return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); | |||
| return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data)); | |||
| } | |||
| public static void dismantle_graph(Graph graph) | |||
| @@ -20,8 +20,12 @@ using System.Linq; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Common.Extensions; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.Graphs.SubGraphUtility; | |||
| using Tensorflow.Util; | |||
| using Tensorflow.Common.Types; | |||
| using System.Diagnostics; | |||
| namespace Tensorflow.Keras | |||
| { | |||
| @@ -450,5 +454,526 @@ namespace Tensorflow.Keras | |||
| return x; | |||
| } | |||
| public (Tensors, Tensors, Tensors) rnn( | |||
| Func<Tensors, Tensors, (Tensors, Tensors)> step_function, // args:inputs, states, return:output, new_states | |||
| Tensors inputs, // inputs is a tuple of tensors (one per input sequence) | |||
| Tensors initial_states, | |||
| bool go_backwards = false, | |||
| Tensor? mask = null, | |||
| Tensors? constants = null, | |||
| bool unroll = false, | |||
| Tensors? input_length = null, // An integer or a 1-D Tensor, depending on whether the time dimension is fixed-length or not | |||
| bool time_major = false, | |||
| bool zero_output_for_mask = false, | |||
| bool return_all_outputs = true) | |||
| { | |||
| Tensor swap_batch_timestep(Tensor input_t) | |||
| { | |||
| var axes = Enumerable.Range(0, input_t.rank).ToArray(); | |||
| axes[0] = 1; | |||
| axes[1] = 0; | |||
| return tf.transpose(input_t, axes); | |||
| } | |||
| if (!time_major) | |||
| { | |||
| inputs = Nest.MapStructure(swap_batch_timestep, inputs).ToTensors(); | |||
| } | |||
| var flatted_inptus = Nest.Flatten(inputs).ToList(); | |||
| var first_flatted_input = flatted_inptus[0]; | |||
| var time_steps = first_flatted_input.shape[0]; | |||
| var batch = first_flatted_input.shape[1]; | |||
| var time_steps_t = tf.shape(first_flatted_input)[0]; | |||
| foreach (var input_ in flatted_inptus) | |||
| { | |||
| input_.shape.with_rank_at_least(3); | |||
| } | |||
| if (mask != null) | |||
| { | |||
| if (mask.dtype != TF_DataType.TF_BOOL) | |||
| { | |||
| mask = tf.cast(mask, TF_DataType.TF_BOOL); | |||
| } | |||
| if (mask.rank == 2) | |||
| { | |||
| mask = tf.expand_dims(mask, -1); | |||
| } | |||
| if (!time_major) | |||
| { | |||
| mask = swap_batch_timestep(mask); | |||
| } | |||
| } | |||
| // tf.where needs its condition tensor to be the same shape as its two | |||
| // result tensors, but in our case the condition (mask) tensor is | |||
| // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. | |||
| // So we need to broadcast the mask to match the shape of inputs. | |||
| // That's what the tile call does, it just repeats the mask along its | |||
| // second dimension n times. | |||
| Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1) | |||
| { | |||
| if (!mask_t.IsSingle()) | |||
| { | |||
| throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}"); | |||
| } | |||
| if (!input_t.IsSingle()) | |||
| { | |||
| throw new ValueError($"input_t is expected to be tensor, but got {input_t}"); | |||
| } | |||
| var rank_diff = input_t.rank - mask_t.rank; | |||
| for (int i = 0; i < rank_diff; i++) | |||
| { | |||
| mask_t = tf.expand_dims(mask_t, -1); | |||
| } | |||
| var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().Skip(fixed_dim).ToArray()); | |||
| return tf.tile(mask_t, multiples); | |||
| } | |||
| Tensors outputs = new Tensors(); | |||
| Tensors output_time_zero = new Tensors(); | |||
| Tensors last_output = new Tensors(); | |||
| Tensors new_states = new Tensors(); | |||
| if (unroll) | |||
| { | |||
| if (time_steps == 0) | |||
| { | |||
| throw new ValueError("Unrolling requires a fixed number of timesteps."); | |||
| } | |||
| // Process the input tensors. They need to be split on the time_step dim | |||
| // and reversed if go_backwards is true. In the case of nested input, the | |||
| // input is flattened first and then each flat tensor is transformed | |||
| // individually. The result is a tuple of lists; each item in the tuple is | |||
| // a list of tensors with shape (batch, feature). | |||
| // TODO(Wanglongzhi2001): step_function's second parameter is declared as a List, but a tuple ends up being used. | |||
| //var states = Tuple.Create(initial_states); | |||
| var states = initial_states; | |||
| var successive_states = new Tensors(); | |||
| var successive_outputs = new Tensors(); | |||
| Tensors _process_single_input_t(Tensor input_t) | |||
| { | |||
| var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim | |||
| if (go_backwards) | |||
| { | |||
| unstaked_input_t = unstaked_input_t.Reverse().ToArray(); | |||
| } | |||
| return unstaked_input_t; | |||
| } | |||
| // TODO(Wanglongzhi2001) | |||
| Tensors processed_input; | |||
| if (!inputs.IsSingle()) | |||
| { | |||
| processed_input = inputs.MapStructure(_process_single_input_t).ReduceTo<Tensors, Tensor>().ToTensors(); | |||
| } | |||
| else | |||
| { | |||
| processed_input = _process_single_input_t(inputs); | |||
| } | |||
| object _get_input_tensor(int time) | |||
| { | |||
| List<Tensor> inp = new List<Tensor>(); | |||
| foreach (var t_ in processed_input) | |||
| { | |||
| inp.Add(t_[time]); | |||
| } | |||
| return Nest.PackSequenceAs(inputs, inp); | |||
| } | |||
| if (mask != null) | |||
| { | |||
| var mask_list = tf.unstack(mask); | |||
| if (go_backwards) | |||
| { | |||
| mask_list = mask_list.Reverse().ToArray(); | |||
| } | |||
| for (int i = 0; i < time_steps; i++) | |||
| { | |||
| // TODO(Wanglongzhi2001),deal with _get_input_tensor | |||
| var inp = _get_input_tensor(i); | |||
| var mask_t = mask_list[i]; | |||
| // TODO | |||
| var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants)); | |||
| var tiled_mask_t = _expand_mask(mask_t, output); | |||
| Tensors prev_output; | |||
| if (successive_outputs == null) | |||
| { | |||
| prev_output = tf.zeros_like(output); | |||
| } | |||
| else | |||
| { | |||
| prev_output = successive_outputs.Last(); | |||
| } | |||
| // output could be a tensor | |||
| output = tf.where(tiled_mask_t, output, prev_output); | |||
| var flat_states = Nest.Flatten(states).ToList(); | |||
| var flat_new_states = Nest.Flatten(newStates).ToList(); | |||
| var tiledMaskT = flat_states | |||
| .Select(s => _expand_mask(mask_t, s)) | |||
| .ToArray(); | |||
| var tuple = Tuple.Create(tiledMaskT); | |||
| List<Tensor> flat_final_states = new List<Tensor>(); | |||
| foreach (var (m, s, ps) in zip(tiled_mask_t.ToList(), flat_new_states, flat_states)) | |||
| { | |||
| flat_final_states.Add(tf.where(m, s, ps)); | |||
| } | |||
| states = Nest.PackSequenceAs(states, flat_final_states).ToTensors(); | |||
| if (return_all_outputs) | |||
| { | |||
| successive_outputs = successive_outputs.MergeWith(output); | |||
| successive_states = successive_states.MergeWith(states); | |||
| } | |||
| else | |||
| { | |||
| successive_outputs = new Tensors(output); | |||
| successive_states = new Tensors(states); | |||
| } | |||
| } | |||
| last_output = successive_outputs.Last(); | |||
| new_states = successive_states.Last(); | |||
| outputs = tf.stack(successive_outputs); | |||
| if (zero_output_for_mask) | |||
| { | |||
| last_output = tf.where(_expand_mask(mask_list.Last(), last_output), last_output, tf.zeros_like(last_output)); | |||
| outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); | |||
| } | |||
| } | |||
| else // mask is null | |||
| { | |||
| for (int i = 0; i < time_steps; i++) | |||
| { | |||
| var inp = _get_input_tensor(i); | |||
| var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants)); | |||
| states = newStates; | |||
| if (return_all_outputs) | |||
| { | |||
| successive_outputs.Add(output); | |||
| successive_states.Add(newStates); | |||
| } | |||
| else | |||
| { | |||
| successive_outputs = new Tensors { output }; | |||
| successive_states = new Tensors { newStates }; | |||
| } | |||
| } | |||
| last_output = successive_outputs.Last(); | |||
| new_states = successive_states.Last(); | |||
| outputs = tf.stack(successive_outputs); | |||
| } | |||
| } | |||
| else // unroll == false | |||
| { | |||
| var states = initial_states; | |||
| // Create input tensor array, if the inputs is nested tensors, then it | |||
| // will be flattened first, and tensor array will be created one per | |||
| // flattened tensor. | |||
| var input_ta = new List<TensorArray>(); | |||
| for (int i = 0; i < flatted_inptus.Count; i++) | |||
| { | |||
| input_ta.Add(TensorArray.Create(dtype: flatted_inptus[i].dtype, size: time_steps_t)); | |||
| } | |||
| foreach(var (ta, input_) in zip(input_ta, flatted_inptus)) | |||
| { | |||
| if (!go_backwards) | |||
| { | |||
| ta.unstack(input_); | |||
| } | |||
| else | |||
| { | |||
| ta.unstack(reverse(input_, 0)); | |||
| } | |||
| } | |||
| // Get the time(0) input and compute the output for that, the output will | |||
| // be used to determine the dtype of output tensor array. Don't read from | |||
| // input_ta due to TensorArray clear_after_read default to True. | |||
| var input_time_zero = Nest.PackSequenceAs(inputs, flatted_inptus.Select(x => x[0]).ToArray()).ToTensors(); | |||
| // output_time_zero is used to determine the cell output shape and its | |||
| // dtype. the value is discarded. | |||
| (output_time_zero, _) = step_function(input_time_zero, | |||
| constants is null ? initial_states : initial_states.MergeWith(constants)); | |||
| Tensor output_ta_size = return_all_outputs ? time_steps_t : constant_op.constant(1); | |||
| var output_ta = new List<TensorArray>(); | |||
| foreach(var output in output_time_zero.Flatten()) | |||
| { | |||
| output_ta.Add(TensorArray.Create(dtype: output.dtype, size: output_ta_size, element_shape: output.shape)); | |||
| } | |||
| var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); | |||
| Func<Tensor, Tensor>? masking_fn; | |||
| Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null; | |||
| if (mask != null) | |||
| { | |||
| if (go_backwards) | |||
| { | |||
| mask = tf.reverse(mask, axis: new[] { 0 }); | |||
| } | |||
| var mask_ta = TensorArray.Create(dtype: TF_DataType.TF_BOOL, size: time_steps_t); | |||
| mask_ta = mask_ta.unstack(mask); | |||
| masking_fn = (time) => | |||
| { | |||
| return mask_ta.read(time); | |||
| }; | |||
| compute_masked_output = (mask_t, flat_out, flat_mask) => | |||
| { | |||
| var tiled_mask_t = new Tensors(); | |||
| foreach (var o in flat_out) | |||
| { | |||
| tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); | |||
| } | |||
| Tensors res = new Tensors(); | |||
| foreach (var (m, o, fm) in zip(tiled_mask_t.ToList(), flat_out.ToList(), flat_mask.ToList())) | |||
| { | |||
| res.Add(tf.where(m, o, fm)); | |||
| } | |||
| return res; | |||
| }; | |||
| } | |||
| // TODO(Wanglongzhi2001): decide what input_length's type should be; it could be an integer or a 1-D tensor. | |||
| else if (input_length is Tensor) | |||
| { | |||
| if (go_backwards) | |||
| { | |||
| var max_len = tf.reduce_max(input_length, axis: 0); | |||
| var rev_input_length = tf.subtract(max_len - 1, input_length); | |||
| masking_fn = (time) => | |||
| { | |||
| return tf.less(rev_input_length, time); | |||
| }; | |||
| } | |||
| else | |||
| { | |||
| masking_fn = (time) => | |||
| { | |||
| return tf.greater(input_length, time); | |||
| }; | |||
| } | |||
| compute_masked_output = (mask_t, flat_out, flat_mask) => | |||
| { | |||
| var res = new List<Tensor>(); | |||
| foreach (var (o, zo) in zip(flat_out, flat_mask)) | |||
| { | |||
| res.Add(tf.where(mask_t, o, zo)); | |||
| } | |||
| return res; | |||
| }; | |||
| } | |||
| else | |||
| { | |||
| masking_fn = null; | |||
| } | |||
| Func<Tensors, Tensor> cond = (time) => (time[0] < time_steps_t); | |||
| int parallel_iterations = 32; | |||
| Tensors final_outputs; | |||
| if (masking_fn != null) | |||
| { | |||
| // The mask for the output at step T is based on the output at step T - 1. In | |||
| // the case T = 0, a zero-filled tensor is used. | |||
| var flat_zero_output = new Tensors(); | |||
| foreach (var o in Nest.Flatten(output_time_zero)) | |||
| { | |||
| flat_zero_output.Add(tf.zeros_like(o)); | |||
| } | |||
| var prev_output = flat_zero_output; | |||
| var output_ta_t = output_ta; | |||
| Tensors _step(Tensors tensors) | |||
| { | |||
| /* | |||
| RNN step function. | |||
| Args: | |||
| time: Current timestep value. | |||
| output_ta_t: TensorArray. | |||
| prev_output: tuple of outputs from time - 1. | |||
| *states: List of states. | |||
| Returns: | |||
| Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` | |||
| */ | |||
| Tensor time = tensors[0]; | |||
| TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray; | |||
| Tensors prev_output = tensors.GetShallow(2); | |||
| Tensors states = new Tensors(tensors.Skip(2 + prev_output.Length).ToArray()); | |||
| var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); | |||
| // maybe set shape | |||
| // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | |||
| var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); | |||
| var mask_t = masking_fn(time); | |||
| var (output, new_states) = step_function(current_input, states.MergeWith(constants)); | |||
| // mask output | |||
| var flat_output = Nest.Flatten(output).ToList(); | |||
| var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.Flatten().ToList(); | |||
| // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type | |||
| var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); | |||
| // mask states | |||
| var flat_state = states.Flatten().ToList(); | |||
| var flat_new_state = new_states.Flatten().ToList(); | |||
| foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | |||
| { | |||
| if (new_state is Tensor) | |||
| { | |||
| new_state.shape = state.shape; | |||
| } | |||
| } | |||
| var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); | |||
| new_states = Nest.PackSequenceAs(new_states, flat_final_state.ToArray()).ToTensors(); | |||
| var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | |||
| Debug.Assert(flat_output.Count() == 1); | |||
| output_ta_t = output_ta_t.write(ta_index_to_write, flat_new_output.First()); | |||
| return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(flat_new_output).Concat(new_states) | |||
| .ToArray().ToTensors(); | |||
| } | |||
| var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) } | |||
| .Concat(flat_zero_output.Flatten()).Concat(states).ToArray().ToTensors(); | |||
| final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations); | |||
| new_states = final_outputs.Skip(3).ToList(); | |||
| } | |||
| else | |||
| { | |||
| var output_ta_t = output_ta; | |||
| new_states = states; | |||
| Tensors _step(Tensors tensors) | |||
| { | |||
| Tensor time = tensors[0]; | |||
| TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray; | |||
| Tensors states = new Tensors(tensors.Skip(2).ToArray()); | |||
| var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); | |||
| // maybe set shape | |||
| // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | |||
| var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); | |||
| var (output, new_states) = step_function(current_input, states.MergeWith(constants)); | |||
| var flat_state = states.Flatten().ToList(); | |||
| var flat_new_state = new_states.Flatten().ToList(); | |||
| foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | |||
| { | |||
| if (new_state is Tensor) | |||
| { | |||
| new_state.shape = state.shape; | |||
| } | |||
| } | |||
| var flat_output = Nest.Flatten(output); | |||
| var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | |||
| Debug.Assert(flat_output.Count() == 1); | |||
| output_ta_t = output_ta_t.write(ta_index_to_write, flat_output.First()); | |||
| new_states = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); | |||
| return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(new_states).ToArray().ToTensors(); | |||
| } | |||
| Debug.Assert(output_ta.Count == 1); | |||
| var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) }.Concat(states).ToArray().ToTensors(); | |||
| final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations); | |||
| new_states = final_outputs.Skip(2).ToList(); | |||
| } | |||
| output_ta = new List<TensorArray> { (final_outputs[1] as FakeTensorByTensorArray).TensorArray }; | |||
| outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToArray().ToTensors()); | |||
| last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToArray().ToTensors()); | |||
| outputs = Nest.PackSequenceAs(output_time_zero, (Tensor[])outputs).ToTensors(); | |||
| last_output = Nest.PackSequenceAs(output_time_zero, (Tensor[])last_output).ToTensors(); | |||
| } | |||
| Func<Tensor, Tensor> set_shape; | |||
| set_shape = (output_) => | |||
| { | |||
| if (output_ is Tensor) | |||
| { | |||
| var shape = output_.shape.as_int_list(); | |||
| if (return_all_outputs) | |||
| { | |||
| shape[0] = (int)time_steps; | |||
| } | |||
| else | |||
| { | |||
| shape[0] = 1; | |||
| } | |||
| shape[1] = (int)batch; | |||
| output_.shape = shape; | |||
| } | |||
| return output_; | |||
| }; | |||
| outputs = Nest.MapStructure(set_shape, outputs).ToTensors(); | |||
| if (!time_major) | |||
| { | |||
| outputs = Nest.MapStructure(swap_batch_timestep, outputs).ToTensors(); | |||
| } | |||
| return (last_output, outputs, new_states); | |||
| } | |||
| public Tensor reverse(Tensor input, int axis) | |||
| { | |||
| return reverse(input, new int[] { axis }); | |||
| } | |||
| public Tensor reverse(Tensor input, int[] axes) | |||
| { | |||
| return tf.reverse(input, axes); | |||
| } | |||
| public Tensor maybe_convert_to_ragged(bool is_ragged_output, Tensor output, int nested_row_lengths, bool go_backwards = false) | |||
| { | |||
| if (!is_ragged_output) | |||
| { | |||
| return output; | |||
| } | |||
| throw new NotImplementedException("Not implemented currently, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| } | |||
| } | |||
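A hedged end-to-end sketch of the new rnn(...) helper above: the identity step function and the shapes are made up, and keras.backend is assumed to expose the method the same way the Python API does. With return_all_outputs left at its default, outputs keeps the full time dimension.

using System;
using Tensorflow;
using Tensorflow.Common.Types;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

// Identity "cell": the output is the current input and the state passes through unchanged.
Func<Tensors, Tensors, (Tensors, Tensors)> step = (input_t, states) => (input_t, states);

var inputs = tf.ones(new Shape(2, 5, 3));                    // (batch, time, features)
var initial_states = new Tensors(tf.zeros(new Shape(2, 3))); // one state tensor per sample

// keras.backend.rnn is an assumed accessor for the method defined above.
var (last_output, outputs, new_states) = keras.backend.rnn(
    step, inputs, initial_states, time_major: false, unroll: false);

// last_output: shape (2, 3); outputs: shape (2, 5, 3), batch-major again since time_major is false.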
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.Keras.Utils; | |||
| @@ -81,7 +82,7 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| else | |||
| { | |||
| _buildInputShape = new Saving.TensorShapeConfig(); | |||
| _buildInputShape = new TensorShapeConfig(); | |||
| } | |||
| if (outputs.Any(x => x.KerasHistory == null)) | |||
| @@ -325,7 +326,7 @@ namespace Tensorflow.Keras.Engine | |||
| nodes_in_decreasing_depth.append(node); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| var tensor_dict = new Dictionary<long, Queue<Tensor>>(); | |||
| // map input values | |||
| @@ -1,4 +1,5 @@ | |||
| using System.Threading; | |||
| using Tensorflow.Common.Types; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Engine | |||
| @@ -8,11 +9,11 @@ namespace Tensorflow.Keras.Engine | |||
| /// <summary> | |||
| /// Wraps `call`, applying pre- and post-processing steps. | |||
| /// </summary> | |||
| /// <param name="input"></param> | |||
| /// <param name="inputs"></param> | |||
| /// <param name="state"></param> | |||
| /// <param name="training"></param> | |||
| /// <returns></returns> | |||
| public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false) | |||
| public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null) | |||
| { | |||
| if (callContext.Value == null) | |||
| callContext.Value = new CallContext(); | |||
| @@ -30,13 +31,15 @@ namespace Tensorflow.Keras.Engine | |||
| if (!built) | |||
| MaybeBuild(inputs); | |||
| var outputs = Call(inputs, state: state, training: training); | |||
| var outputs = Call(inputs, state: states, training: training); | |||
| // memory leak | |||
| // _set_connectivity_metadata_(inputs, outputs); | |||
| _handle_activity_regularization(inputs, outputs); | |||
| _set_mask_metadata(inputs, outputs, null); | |||
| // TODO(Rinne): set save spec if null | |||
| scope.__exit__(); | |||
| return outputs; | |||
| @@ -32,7 +32,7 @@ using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Sessions; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -332,7 +332,7 @@ namespace Tensorflow.Keras.Engine | |||
| /// <param name="state"></param> | |||
| /// <param name="training"></param> | |||
| /// <returns></returns> | |||
| protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected virtual Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| if(ReplacedCall is not null) | |||
| { | |||
| @@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Engine | |||
| var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); | |||
| graph.as_default(); | |||
| var shapes = input_shape.ToShapeArray(); | |||
| var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x))); | |||
| var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x)).ToArray()); | |||
| try | |||
| { | |||
| Call(x, training: false); | |||
| @@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| var data_handler = new DataHandler(new DataHandlerArgs | |||
| { | |||
| X = new Tensors(x), | |||
| X = new Tensors(x.ToArray()), | |||
| Y = y, | |||
| Model = this, | |||
| StepsPerExecution = _steps_per_execution | |||
| @@ -168,7 +168,8 @@ namespace Tensorflow.Keras.Engine | |||
| Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data) | |||
| { | |||
| var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | |||
| var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); | |||
| var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); | |||
| tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | |||
| return outputs; | |||
| } | |||
| } | |||
| @@ -110,7 +110,7 @@ namespace Tensorflow.Keras.Engine | |||
| var data_handler = new DataHandler(new DataHandlerArgs | |||
| { | |||
| X = new Tensors(train_x), | |||
| X = new Tensors(train_x.ToArray()), | |||
| Y = train_y, | |||
| BatchSize = batch_size, | |||
| InitialEpoch = initial_epoch, | |||
| @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| var data = iterator.next(); | |||
| var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | |||
| var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); | |||
| var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); | |||
| tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | |||
| return outputs; | |||
| } | |||
| @@ -1,8 +1,8 @@ | |||
| using System.Diagnostics; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Train; | |||
| @@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.KerasApi; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -143,7 +144,7 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| if (!_has_explicit_input_shape) | |||
| { | |||
| @@ -0,0 +1,4 @@ | |||
| namespace System.Runtime.CompilerServices | |||
| { | |||
| internal static class IsExternalInit { } | |||
| } | |||
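For context, a hedged sketch of why this shim matters (the type below is made up): on netstandard2.0 the compiler-generated code for init accessors and records references System.Runtime.CompilerServices.IsExternalInit, which that target framework does not ship, so either this shim or the IsExternalInit package added to the csproj must supply it.

// Compiles on netstandard2.0 only because IsExternalInit is available.
public class TensorSliceSpec
{
    public string Name { get; init; }
    public int Axis { get; init; }
}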
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| @@ -29,7 +30,7 @@ namespace Tensorflow.Keras.Layers { | |||
| base.build(input_shape); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor output = inputs; | |||
| output = tf.where(output > 0f, output, | |||
| @@ -4,7 +4,7 @@ using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers { | |||
| public class Exponential : Layer | |||
| @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers { | |||
| { | |||
| base.build(input_shape); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor output = inputs; | |||
| return tf.exp(output); | |||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Common.Types; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Layers { | |||
| @@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Layers { | |||
| public HardSigmoid ( LayerArgs args ) : base(args) { | |||
| // hard sigmoid has no arguments | |||
| } | |||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||
| protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null ) { | |||
| Tensor x = inputs; | |||
| return tf.clip_by_value( | |||
| tf.add(tf.multiply(x, 0.2f), 0.5f), 0f, 1f); | |||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Common.Types; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Layers | |||
| @@ -19,7 +20,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.args = args; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| return tf.nn.leaky_relu(inputs, alpha: alpha); | |||
| } | |||
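The remaining hunks repeat the same mechanical change across the activation layers. As a hedged sketch (the layer name and body are made up), a user-defined layer now overrides Call with a Tensors state and a trailing optional-args parameter:

using Tensorflow;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

public class PassThrough : Layer
{
    public PassThrough(LayerArgs args) : base(args) { }

    protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
    {
        return inputs;   // identity layer: forward the inputs unchanged
    }
}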
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| @@ -22,7 +23,7 @@ namespace Tensorflow.Keras.Layers { | |||
| } | |||
| base.build(input_shape); | |||
| } | |||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||
| protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { | |||
| Tensor output = inputs; | |||
| return tf.where(output > 0f, | |||
| tf.multiply(scale, output), | |||