feat: add rnn basic modules (tags/v0.110.0-LSTM-Model)
| @@ -3,16 +3,16 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Extensions | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| public static class JObjectExtensions | |||
| { | |||
| public static T? TryGetOrReturnNull<T>(this JObject obj, string key) | |||
| { | |||
| var res = obj[key]; | |||
| if(res is null) | |||
| if (res is null) | |||
| { | |||
| return default(T); | |||
| return default; | |||
| } | |||
| else | |||
| { | |||
| @@ -0,0 +1,33 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| /// <summary> | |||
| /// Sequence and tuple helpers shared across the library. | |||
| /// </summary> | |||
| public static class LinqExtensions | |||
| { | |||
| #if NETSTANDARD2_0 | |||
| /// <summary> | |||
| /// Returns the last <paramref name="count"/> elements of <paramref name="sequence"/>. | |||
| /// Only compiled when NETSTANDARD2_0 is defined. | |||
| /// </summary> | |||
| /// <param name="sequence">The source sequence.</param> | |||
| /// <param name="count">The number of trailing elements to keep.</param> | |||
| public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count) | |||
| { | |||
| // Snapshot once: counting and then skipping on the raw sequence would enumerate it twice, | |||
| // which is wasteful and gives wrong results for sources that cannot be replayed. | |||
| var items = sequence as ICollection<T> ?? sequence.ToList(); | |||
| return items.Skip(items.Count - count); | |||
| } | |||
| /// <summary> | |||
| /// Returns <paramref name="sequence"/> without its last <paramref name="count"/> elements. | |||
| /// Only compiled when NETSTANDARD2_0 is defined. | |||
| /// </summary> | |||
| /// <param name="sequence">The source sequence.</param> | |||
| /// <param name="count">The number of trailing elements to drop.</param> | |||
| public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count) | |||
| { | |||
| // Same single-enumeration snapshot as in TakeLast. | |||
| var items = sequence as ICollection<T> ?? sequence.ToList(); | |||
| return items.Take(items.Count - count); | |||
| } | |||
| #endif | |||
| /// <summary> | |||
| /// Wraps the given tensors in a new <see cref="Tensors"/> instance. | |||
| /// </summary> | |||
| public static Tensors ToTensors(this IEnumerable<Tensor> tensors) | |||
| { | |||
| return new Tensors(tensors); | |||
| } | |||
| /// <summary> | |||
| /// Splits a three-element tuple into its components. | |||
| /// </summary> | |||
| public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) | |||
| { | |||
| first = values.Item1; | |||
| second = values.Item2; | |||
| third = values.Item3; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,33 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Common.Extensions | |||
| { | |||
| /// <summary> | |||
| /// Conversion helpers between nest structures and <see cref="Tensors"/>. | |||
| /// </summary> | |||
| public static class NestExtensions | |||
| { | |||
| /// <summary> | |||
| /// Builds a new <see cref="Tensors"/> from the nest representation of <paramref name="tensors"/>. | |||
| /// </summary> | |||
| public static Tensors ToTensors(this INestable<Tensor> tensors) | |||
| => new Tensors(tensors.AsNest()); | |||
| /// <summary> | |||
| /// Converts the nest through <c>Tensors.FromNest</c>; the declared result is nullable. | |||
| /// </summary> | |||
| public static Tensors? ToTensors(this Nest<Tensor> tensors) | |||
| => Tensors.FromNest(tensors); | |||
| /// <summary> | |||
| /// Collapses a structure whose elements are themselves nest structures. | |||
| /// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. | |||
| /// </summary> | |||
| /// <typeparam name="TIn">The element type of the input, itself a nest structure of <typeparamref name="TOut"/>.</typeparam> | |||
| /// <typeparam name="TOut">The element type of the reduced nest.</typeparam> | |||
| /// <param name="input">The structure to reduce.</param> | |||
| /// <returns>The result of <c>Nest&lt;TOut&gt;.ReduceFrom</c> applied to <paramref name="input"/>.</returns> | |||
| public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut> | |||
| => Nest<TOut>.ReduceFrom(input); | |||
| } | |||
| } | |||
| @@ -0,0 +1,130 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?> | |||
| { | |||
| /// <summary> | |||
| /// The shape configs held by this instance, one entry per shape. | |||
| /// </summary> | |||
| public TensorShapeConfig[] Shapes { get; set; } | |||
| /// <summary> | |||
| /// Create a generalized tensor shape holding one shape with a single dimension. | |||
| /// </summary> | |||
| /// <param name="dim">The size of that single dimension.</param> | |||
| public GeneralizedTensorShape(int dim) | |||
| { | |||
| Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } }; | |||
| } | |||
| /// <summary> | |||
| /// Create a generalized tensor shape holding exactly one shape. | |||
| /// </summary> | |||
| /// <param name="shape">The shape to store; it is converted to a <see cref="TensorShapeConfig"/>.</param> | |||
| public GeneralizedTensorShape(Shape shape) | |||
| { | |||
| Shapes = new TensorShapeConfig[] { shape }; | |||
| } | |||
| /// <summary> | |||
| /// Create a generalized tensor shape holding exactly one shape config. | |||
| /// </summary> | |||
| /// <param name="shape">The shape config to store.</param> | |||
| public GeneralizedTensorShape(TensorShapeConfig shape) | |||
| { | |||
| Shapes = new TensorShapeConfig[] { shape }; | |||
| } | |||
| /// <summary> | |||
| /// Create a generalized tensor shape from an array of shape configs. | |||
| /// The array is stored as given, without copying. | |||
| /// </summary> | |||
| /// <param name="shapes">The shape configs to store.</param> | |||
| public GeneralizedTensorShape(TensorShapeConfig[] shapes) | |||
| { | |||
| Shapes = shapes; | |||
| } | |||
| /// <summary> | |||
| /// Create a generalized tensor shape from a sequence of shapes, each converted to a | |||
| /// <see cref="TensorShapeConfig"/>. | |||
| /// </summary> | |||
| /// <param name="shape">The shapes to store.</param> | |||
| public GeneralizedTensorShape(IEnumerable<Shape> shape) | |||
| { | |||
| Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray(); | |||
| } | |||
| /// <summary> | |||
| /// Converts this instance to a single <see cref="Shape"/>; unknown (null) dimensions become -1. | |||
| /// </summary> | |||
| /// <exception cref="ValueError">Thrown when the instance does not hold exactly one shape.</exception> | |||
| public Shape ToSingleShape() | |||
| { | |||
| if (Shapes.Length == 1) | |||
| { | |||
| var config = Shapes[0]; | |||
| Debug.Assert(config is not null); | |||
| var dims = config.Items.Select(dim => dim ?? -1).ToArray(); | |||
| return new Shape(dims); | |||
| } | |||
| throw new ValueError("The generalized shape contains more than 1 dim."); | |||
| } | |||
| /// <summary> | |||
| /// Returns the only dimension of the only shape, or -1 when that dimension is unknown (null). | |||
| /// </summary> | |||
| /// <exception cref="ValueError">Thrown unless there is exactly one shape with exactly one dimension.</exception> | |||
| public long ToNumber() | |||
| { | |||
| bool isSingleNumber = Shapes.Length == 1 && Shapes[0].Items.Length == 1; | |||
| if (!isSingleNumber) | |||
| { | |||
| throw new ValueError("The generalized shape contains more than 1 dim."); | |||
| } | |||
| return Shapes[0].Items[0] ?? -1; | |||
| } | |||
| /// <summary> | |||
| /// Converts every stored shape config to a <see cref="Shape"/>; unknown (null) dimensions become -1. | |||
| /// </summary> | |||
| public Shape[] ToShapeArray() | |||
| { | |||
| return Shapes | |||
| .Select(config => config.Items.Select(dim => dim ?? -1).ToArray()) | |||
| .Select(dims => new Shape(dims)) | |||
| .ToArray(); | |||
| } | |||
| /// <summary> | |||
| /// Returns the dimensions of all shapes, concatenated in shape order, as a new list. | |||
| /// </summary> | |||
| public IEnumerable<long?> Flatten() | |||
| { | |||
| return Shapes.SelectMany(config => config.Items).ToList(); | |||
| } | |||
| /// <summary> | |||
| /// Applies <paramref name="func"/> to every dimension and returns a two-level list nest: | |||
| /// one inner list per shape, one node per dimension. | |||
| /// </summary> | |||
| /// <typeparam name="TOut">The element type of the resulting nest.</typeparam> | |||
| /// <param name="func">The mapping applied to each dimension.</param> | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func) | |||
| { | |||
| var perShape = Shapes.Select( | |||
| config => new Nest<TOut>(config.Items.Select(dim => new Nest<TOut>(func(dim))))); | |||
| return new Nest<TOut>(perShape); | |||
| } | |||
| /// <summary> | |||
| /// Converts this instance to a <see cref="Nest{T}"/>. A shape with no dimensions maps to the empty | |||
| /// nest, a one-dimension shape to a single node, anything longer to a list of nodes. The same | |||
| /// collapsing applies one level up: no shapes give the empty nest and one shape is returned unwrapped. | |||
| /// </summary> | |||
| public Nest<long?> AsNest() | |||
| { | |||
| Nest<long?> ConvertShape(TensorShapeConfig config) => config.Items.Length switch | |||
| { | |||
| 0 => Nest<long?>.Empty, | |||
| 1 => new Nest<long?>(config.Items[0]), | |||
| _ => new Nest<long?>(config.Items.Select(dim => new Nest<long?>(dim))) | |||
| }; | |||
| return Shapes.Length switch | |||
| { | |||
| 0 => Nest<long?>.Empty, | |||
| 1 => ConvertShape(Shapes[0]), | |||
| _ => new Nest<long?>(Shapes.Select(config => ConvertShape(config))) | |||
| }; | |||
| } | |||
| /// <summary> | |||
| /// Enumerates the dimension array of each stored shape, in order. | |||
| /// </summary> | |||
| public IEnumerator<long?[]> GetEnumerator() | |||
| { | |||
| foreach (var config in Shapes) | |||
| { | |||
| long?[] dims = config.Items; | |||
| yield return dims; | |||
| } | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,27 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// This interface indicates that a class may have a nested structure and provides | |||
| /// methods to manipulate that structure. | |||
| /// </summary> | |||
| /// <typeparam name="T">The type of the values stored in the structure.</typeparam> | |||
| public interface INestStructure<T>: INestable<T> | |||
| { | |||
| /// <summary> | |||
| /// Flatten the nestable object. Note that if the object contains only one value, | |||
| /// it will be flattened to an enumerable with one element. | |||
| /// </summary> | |||
| /// <returns>The values of the structure as a flat sequence.</returns> | |||
| IEnumerable<T> Flatten(); | |||
| /// <summary> | |||
| /// Construct a new object with the same nested structure. | |||
| /// </summary> | |||
| /// <typeparam name="TOut">The type of the values in the new structure.</typeparam> | |||
| /// <param name="func">The function used to produce each new value from an existing one.</param> | |||
| /// <returns>The newly constructed structure.</returns> | |||
| INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func); | |||
| } | |||
| } | |||
| @@ -0,0 +1,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// Implemented by types that can present themselves as a <see cref="Nest{T}"/>. | |||
| /// </summary> | |||
| /// <typeparam name="T">The type of the values stored in the nest.</typeparam> | |||
| public interface INestable<T> | |||
| { | |||
| /// <summary> | |||
| /// Returns the <see cref="Nest{T}"/> representation of this object. | |||
| /// </summary> | |||
| Nest<T> AsNest(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,21 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// This interface is used when the corresponding python method takes optional args. | |||
| /// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while | |||
| /// `Keras.Layer.RNN` takes more. When calling RNN, you should therefore pass `RnnOptionalArgs` | |||
| /// as a parameter of the method. | |||
| /// </summary> | |||
| public interface IOptionalArgs | |||
| { | |||
| /// <summary> | |||
| /// The identifier of the class. It is not an argument itself; it only serves to | |||
| /// tell different OptionalArgs implementations apart. | |||
| /// </summary> | |||
| string Identifier { get; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,62 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// Static entry points that convert an <see cref="INestable{T}"/> to its nest and forward to | |||
| /// the matching <see cref="Nest{T}"/> member. | |||
| /// </summary> | |||
| public static class Nest | |||
| { | |||
| /// <summary> | |||
| /// Pack the flat items into a nested sequence shaped like the template. | |||
| /// </summary> | |||
| /// <typeparam name="T">The type of the packed values.</typeparam> | |||
| /// <param name="template">The object whose nest provides the structure.</param> | |||
| /// <param name="flatItems">The values to pack.</param> | |||
| /// <returns>The nest produced by <c>PackSequence</c> on the template's nest.</returns> | |||
| public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems) | |||
| => template.AsNest().PackSequence(flatItems); | |||
| /// <summary> | |||
| /// Pack the flat items into a nested sequence shaped like the template. | |||
| /// The list is copied to an array first. | |||
| /// </summary> | |||
| /// <typeparam name="T">The type of the packed values.</typeparam> | |||
| /// <param name="template">The object whose nest provides the structure.</param> | |||
| /// <param name="flatItems">The values to pack.</param> | |||
| /// <returns>The nest produced by <c>PackSequence</c> on the template's nest.</returns> | |||
| public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems) | |||
| => PackSequenceAs(template, flatItems.ToArray()); | |||
| /// <summary> | |||
| /// Flatten the nested object. | |||
| /// </summary> | |||
| /// <typeparam name="T">The type of the stored values.</typeparam> | |||
| /// <param name="nestedObject">The object to flatten.</param> | |||
| /// <returns>The flattened values of the object's nest.</returns> | |||
| public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject) | |||
| => nestedObject.AsNest().Flatten(); | |||
| /// <summary> | |||
| /// Map the structure with the specified function. | |||
| /// </summary> | |||
| /// <typeparam name="TIn">The type of the existing values.</typeparam> | |||
| /// <typeparam name="TOut">The type of the mapped values.</typeparam> | |||
| /// <param name="func">The function applied to each value.</param> | |||
| /// <param name="nestedObject">The object whose nest is mapped.</param> | |||
| /// <returns>The mapped structure.</returns> | |||
| public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject) | |||
| => nestedObject.AsNest().MapStructure(func); | |||
| /// <summary> | |||
| /// Reports whether the nest of <paramref name="obj"/> is nested, as decided by <c>Nest&lt;T&gt;.IsNested</c>. | |||
| /// </summary> | |||
| public static bool IsNested<T>(INestable<T> obj) | |||
| => obj.AsNest().IsNested(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,458 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Extensions; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// The kind of content a nest holds. | |||
| /// </summary> | |||
| public enum NestType | |||
| { | |||
| // No content. | |||
| Empty, | |||
| // A single value. | |||
| Node, | |||
| // A list of child nests. | |||
| List, | |||
| // A string-keyed dictionary of child nests. | |||
| Dictionary | |||
| } | |||
| /// <summary> | |||
| /// A nested structure which may include a value, a list or a dictionary. | |||
| /// Note that dictionary does not ensure the data order. When using it as IEnumerable, | |||
| /// its order is depth-first. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| public class Nest<T> : INestStructure<T>, IEnumerable<T> | |||
| { | |||
| // Shared instance used to represent an empty nest. | |||
| private static readonly Nest<T> _empty = new Nest<T>() | |||
| { | |||
| NestType = NestType.Empty, | |||
| }; | |||
| /// <summary> | |||
| /// The shared empty nest. | |||
| /// </summary> | |||
| public static Nest<T> Empty => _empty; | |||
| /// <summary> | |||
| /// Tells which kind of content this nest holds. | |||
| /// </summary> | |||
| public NestType NestType { get; protected set; } | |||
| /// <summary> | |||
| /// Optional name of this nest; written before the content by <c>ToString</c>. | |||
| /// </summary> | |||
| public string? Name { get; set; } | |||
| /// <summary> | |||
| /// The value, set by the single-value constructor. | |||
| /// </summary> | |||
| public T? Value { get; protected set; } | |||
| /// <summary> | |||
| /// The child nests, set by the list constructor. | |||
| /// </summary> | |||
| public List<Nest<T>>? ListValue { get; protected set; } | |||
| /// <summary> | |||
| /// The keyed child nests, set by the dictionary constructor. | |||
| /// </summary> | |||
| public Dictionary<string, Nest<T>>? DictValue { get; protected set; } | |||
| /// <summary> | |||
| /// Creates a nest without setting any member. | |||
| /// </summary> | |||
| protected Nest() { } | |||
| /// <summary> | |||
| /// Creates a nest of type <c>Node</c> holding a single value. | |||
| /// </summary> | |||
| /// <param name="value">The value to hold.</param> | |||
| /// <param name="name">Optional name of the nest.</param> | |||
| public Nest(T value, string? name = null) | |||
| { | |||
| Value = value; | |||
| Name = name; | |||
| NestType = NestType.Node; | |||
| } | |||
| /// <summary> | |||
| /// Creates a nest of type <c>List</c>. The sequence is copied into a new list. | |||
| /// </summary> | |||
| /// <param name="values">The child nests.</param> | |||
| /// <param name="name">Optional name of the nest.</param> | |||
| public Nest(IEnumerable<Nest<T>> values, string? name = null) | |||
| { | |||
| ListValue = values.ToList(); | |||
| Name = name; | |||
| NestType = NestType.List; | |||
| } | |||
| /// <summary> | |||
| /// Creates a nest of type <c>Dictionary</c>. The dictionary is stored as given, without copying. | |||
| /// </summary> | |||
| /// <param name="value">The keyed child nests.</param> | |||
| /// <param name="name">Optional name of the nest.</param> | |||
| public Nest(Dictionary<string, Nest<T>> value, string? name = null) | |||
| { | |||
| DictValue = value; | |||
| Name = name; | |||
| NestType = NestType.Dictionary; | |||
| } | |||
| /// <summary> | |||
| /// Creates a shallow copy: the list and dictionary instances of <paramref name="other"/> are | |||
| /// shared with the new nest rather than duplicated. | |||
| /// </summary> | |||
| /// <param name="other">The nest to copy from.</param> | |||
| public Nest(Nest<T> other) | |||
| { | |||
| NestType = other.NestType; | |||
| Value = other.Value; | |||
| DictValue = other.DictValue; | |||
| ListValue = other.ListValue; | |||
| Name = other.Name; | |||
| } | |||
| /// <summary> | |||
| /// Returns the values of this nest as a flat sequence. | |||
| /// </summary> | |||
| public virtual IEnumerable<T> Flatten() => FlattenInternal(this); | |||
| /// <summary> | |||
| /// Builds a nest with the same structure whose values are produced by <paramref name="func"/>. | |||
| /// </summary> | |||
| /// <typeparam name="TOut">The type of the mapped values.</typeparam> | |||
| /// <param name="func">The function applied to each value.</param> | |||
| public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) => MapStructureInternal(func); | |||
| /// <summary> | |||
| /// Pack the flat items into a nested sequence, using this nest as the template. | |||
| /// An empty <paramref name="flatItems"/> array yields the empty nest without consulting the template. | |||
| /// </summary> | |||
| /// <param name="flatItems">The values to place into the structure.</param> | |||
| /// <returns>The packed nest.</returns> | |||
| public virtual Nest<T> PackSequence(T[] flatItems) | |||
| { | |||
| if (flatItems.Length > 0) | |||
| { | |||
| int cursor = 0; | |||
| return PackSequenceInternal(this, flatItems, ref cursor); | |||
| } | |||
| return Nest<T>.Empty; | |||
| } | |||
| /// <summary> | |||
| /// Recursively rebuilds the structure of <paramref name="template"/>, filling each leaf with the | |||
| /// next unused element of <paramref name="flatItems"/>. | |||
| /// </summary> | |||
| /// <param name="template">The nest whose structure is reproduced.</param> | |||
| /// <param name="flatItems">The values placed into the leaves.</param> | |||
| /// <param name="index">Position of the next unused element; advanced once per filled leaf.</param> | |||
| /// <returns>A new nest with the template's structure and the flat items as values.</returns> | |||
| /// <exception cref="InvalidArgumentError"> | |||
| /// Thrown when the flat items run out before every leaf is filled, or when the template contains an empty node. | |||
| /// </exception> | |||
| private static Nest<T> PackSequenceInternal(Nest<T> template, T[] flatItems, ref int index) | |||
| { | |||
| if (template.NestType == NestType.Node) | |||
| { | |||
| if (index >= flatItems.Length) | |||
| { | |||
| throw new InvalidArgumentError("The template and flat items are not matched."); | |||
| } | |||
| return new Nest<T>(flatItems[index++]); | |||
| } | |||
| else if (template.NestType == NestType.List) | |||
| { | |||
| List<Nest<T>> nestedObjects = new List<Nest<T>>(); | |||
| for (int i = 0; i < template.ListValue!.Count; i++) | |||
| { | |||
| nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index)); | |||
| } | |||
| return new Nest<T>(nestedObjects); | |||
| } | |||
| else if (template.NestType == NestType.Dictionary) | |||
| { | |||
| // Must test for Dictionary here: testing for Node again would make this branch unreachable | |||
| // and send every dictionary template to the "empty node" error below. | |||
| // Items are consumed in the dictionary's enumeration order. | |||
| Dictionary<string, Nest<T>> dict = new Dictionary<string, Nest<T>>(); | |||
| foreach (var pair in template.DictValue!) | |||
| { | |||
| dict[pair.Key] = PackSequenceInternal(pair.Value, flatItems, ref index); | |||
| } | |||
| return new Nest<T>(dict); | |||
| } | |||
| // Consider Empty as invalid type. | |||
| throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); | |||
| } | |||
| /// <summary> | |||
| /// Returns this instance; a nest is its own nest representation. | |||
| /// </summary> | |||
| public virtual Nest<T> AsNest() => this; | |||
| /// <summary> | |||
| /// Combines this nest with <paramref name="other"/>. A null or empty side is ignored and the | |||
| /// remaining nest is returned as is. Two lists are concatenated and two dictionaries are merged | |||
| /// (a key present on both sides makes <c>ToDictionary</c> throw). Every other combination | |||
| /// produces a two-element list holding both nests. | |||
| /// </summary> | |||
| /// <param name="other">The nest to merge in; may be null.</param> | |||
| public virtual Nest<T> MergeWith(Nest<T>? other) | |||
| { | |||
| if (other is null || other == Nest<T>.Empty) | |||
| { | |||
| return this; | |||
| } | |||
| if (this == Nest<T>.Empty) | |||
| { | |||
| return other; | |||
| } | |||
| bool bothLists = NestType == NestType.List && other.NestType == NestType.List; | |||
| if (bothLists) | |||
| { | |||
| return new Nest<T>(this.ListValue!.Concat(other.ListValue!)); | |||
| } | |||
| bool bothDictionaries = NestType == NestType.Dictionary && other.NestType == NestType.Dictionary; | |||
| if (bothDictionaries) | |||
| { | |||
| var merged = this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value); | |||
| return new Nest<T>(merged); | |||
| } | |||
| // Two nodes, or any mixed pair, are simply placed side by side. | |||
| return new Nest<T>(new Nest<T>[] { this, other }); | |||
| } | |||
| /// <summary> | |||
| /// Tells whether the nested object is really nested. Despite being called `Nest`, it may hold | |||
| /// a flat structure. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. | |||
| /// </summary> | |||
| /// <returns>True when at least one direct child is itself a list or a dictionary.</returns> | |||
| public bool IsNested() | |||
| { | |||
| if (NestType is NestType.Empty or NestType.Node) | |||
| { | |||
| return false; | |||
| } | |||
| IEnumerable<Nest<T>> children; | |||
| if (NestType is NestType.List) | |||
| { | |||
| children = ListValue!; | |||
| } | |||
| else | |||
| { | |||
| children = DictValue!.Values; | |||
| } | |||
| foreach (var child in children) | |||
| { | |||
| if (child.NestType is NestType.List or NestType.Dictionary) | |||
| { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| /// <summary> | |||
| /// Reads or writes a value by position, as resolved by <c>FindInternal</c> and <c>SetInternal</c>. | |||
| /// </summary> | |||
| /// <exception cref="IndexOutOfRangeException">Thrown when no value is found at <paramref name="index"/>.</exception> | |||
| [Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] | |||
| public T this[int index] | |||
| { | |||
| get | |||
| { | |||
| if (!FindInternal(this, index, out var found)) | |||
| { | |||
| throw new IndexOutOfRangeException(); | |||
| } | |||
| return found; | |||
| } | |||
| set | |||
| { | |||
| if (SetInternal(this, index, value)) | |||
| { | |||
| return; | |||
| } | |||
| throw new IndexOutOfRangeException(); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// If the existing nested structure is of type `Nest[INestStructure[T]]`, it can be reduced | |||
| /// to `Nest[T]`. | |||
| /// </summary> | |||
| /// <typeparam name="TOut">The element type of the input, itself a nest structure of <typeparamref name="T"/>.</typeparam> | |||
| /// <param name="input">The structure to reduce.</param> | |||
| /// <returns>The reduced nest.</returns> | |||
| public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T> | |||
| => ReduceInternal(input.AsNest()); | |||
| /// <summary> | |||
| /// Recursively replaces every leaf of <paramref name="node"/> with that leaf's own nest, | |||
| /// keeping list and dictionary levels in place. | |||
| /// </summary> | |||
| private static Nest<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T> | |||
| { | |||
| switch (node.NestType) | |||
| { | |||
| case NestType.Empty: | |||
| return Nest<T>.Empty; | |||
| case NestType.Node: | |||
| return node.Value!.AsNest(); | |||
| case NestType.List: | |||
| return new Nest<T>(node.ListValue!.Select(child => ReduceInternal(child))); | |||
| default: | |||
| // Dictionary type | |||
| return new Nest<T>(node.DictValue!.ToDictionary(entry => entry.Key, entry => ReduceInternal(entry.Value))); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Looks up a value by position. A node only answers to index 0. For a list or dictionary the | |||
| /// index selects a direct child, which is then queried with index 0. | |||
| /// </summary> | |||
| /// <param name="node">The nest to search.</param> | |||
| /// <param name="index">The position to look up.</param> | |||
| /// <param name="result">The value found, or the default of <typeparamref name="T"/> on failure.</param> | |||
| /// <returns>True when a value was found.</returns> | |||
| private static bool FindInternal(Nest<T> node, int index, out T? result) | |||
| { | |||
| result = default(T); | |||
| IEnumerable<Nest<T>> children; | |||
| switch (node.NestType) | |||
| { | |||
| case NestType.Node: | |||
| if (index != 0) | |||
| { | |||
| return false; | |||
| } | |||
| result = node.Value!; | |||
| return true; | |||
| case NestType.List: | |||
| children = node.ListValue!; | |||
| break; | |||
| case NestType.Dictionary: | |||
| children = node.DictValue!.Values; | |||
| break; | |||
| default: | |||
| return false; | |||
| } | |||
| foreach (var child in children) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| return FindInternal(child, index, out result); | |||
| } | |||
| index--; | |||
| } | |||
| return false; | |||
| } | |||
| /// <summary> | |||
| /// Overwrites a value by position. A node only answers to index 0. For a list or dictionary the | |||
| /// index selects a direct child, which is then updated with index 0. | |||
| /// </summary> | |||
| /// <param name="node">The nest to update.</param> | |||
| /// <param name="index">The position to write to.</param> | |||
| /// <param name="newValue">The value to store.</param> | |||
| /// <returns>True when a value was replaced.</returns> | |||
| private static bool SetInternal(Nest<T> node, int index, T newValue) | |||
| { | |||
| IEnumerable<Nest<T>> children; | |||
| switch (node.NestType) | |||
| { | |||
| case NestType.Node: | |||
| if (index != 0) | |||
| { | |||
| return false; | |||
| } | |||
| node.Value = newValue; | |||
| return true; | |||
| case NestType.List: | |||
| children = node.ListValue!; | |||
| break; | |||
| case NestType.Dictionary: | |||
| children = node.DictValue!.Values; | |||
| break; | |||
| default: | |||
| return false; | |||
| } | |||
| foreach (var child in children) | |||
| { | |||
| if (index == 0) | |||
| { | |||
| return SetInternal(child, index, newValue); | |||
| } | |||
| index--; | |||
| } | |||
| return false; | |||
| } | |||
| /// <summary> | |||
| /// Lazily yields the values below <paramref name="node"/> depth-first. An empty nest yields nothing. | |||
| /// </summary> | |||
| private static IEnumerable<T> FlattenInternal(Nest<T> node) | |||
| { | |||
| if (node.NestType == NestType.Node) | |||
| { | |||
| yield return node.Value!; | |||
| yield break; | |||
| } | |||
| IEnumerable<Nest<T>> children; | |||
| if (node.NestType == NestType.List) | |||
| { | |||
| children = node.ListValue!; | |||
| } | |||
| else if (node.NestType == NestType.Dictionary) | |||
| { | |||
| children = node.DictValue!.Values; | |||
| } | |||
| else | |||
| { | |||
| yield break; | |||
| } | |||
| foreach (var child in children) | |||
| { | |||
| foreach (var leaf in FlattenInternal(child)) | |||
| { | |||
| yield return leaf; | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Recursively builds a nest of the same structure with <paramref name="func"/> applied to each value. | |||
| /// An empty nest maps to the empty nest of the output type. | |||
| /// </summary> | |||
| private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func) | |||
| { | |||
| switch (NestType) | |||
| { | |||
| case NestType.Node: | |||
| return new Nest<TOut>(func(Value!)); | |||
| case NestType.List: | |||
| { | |||
| var mappedItems = new List<Nest<TOut>>(); | |||
| ListValue!.ForEach(child => mappedItems.Add(child.MapStructureInternal(func))); | |||
| return new Nest<TOut>(mappedItems); | |||
| } | |||
| case NestType.Dictionary: | |||
| { | |||
| var mappedEntries = new Dictionary<string, Nest<TOut>>(); | |||
| foreach (var entry in DictValue!) | |||
| { | |||
| mappedEntries.Add(entry.Key, entry.Value.MapStructureInternal(func)); | |||
| } | |||
| return new Nest<TOut>(mappedEntries); | |||
| } | |||
| default: | |||
| return Nest<TOut>.Empty; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Enumerates the flattened values of this nest. | |||
| /// </summary> | |||
| public IEnumerator<T> GetEnumerator() => Flatten().GetEnumerator(); | |||
| IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); | |||
| /// <summary> | |||
| /// Renders the nest as text, wrapped in parentheses. | |||
| /// </summary> | |||
| public override string ToString() | |||
| { | |||
| var builder = new StringBuilder("("); | |||
| WriteString(this, builder); | |||
| return builder.Append(")").ToString(); | |||
| } | |||
| /// <summary> | |||
| /// Appends the text form of <paramref name="node"/> to <paramref name="sb"/>: an optional | |||
| /// "name: " prefix, then the value, a bracketed list, a braced "key: value" dictionary, | |||
| /// or a placeholder for an empty nest. | |||
| /// </summary> | |||
| private static void WriteString(Nest<T> node, StringBuilder sb) | |||
| { | |||
| if (!string.IsNullOrEmpty(node.Name)) | |||
| { | |||
| sb.Append(node.Name).Append(": "); | |||
| } | |||
| switch (node.NestType) | |||
| { | |||
| case NestType.Node: | |||
| sb.Append(node.Value!.ToString()); | |||
| break; | |||
| case NestType.List: | |||
| { | |||
| sb.Append("["); | |||
| bool isFirst = true; | |||
| foreach (var child in node.ListValue!) | |||
| { | |||
| // Separator goes between elements, never after the last one. | |||
| if (!isFirst) | |||
| { | |||
| sb.Append(", "); | |||
| } | |||
| WriteString(child, sb); | |||
| isFirst = false; | |||
| } | |||
| sb.Append("]"); | |||
| break; | |||
| } | |||
| case NestType.Dictionary: | |||
| { | |||
| sb.Append("{"); | |||
| bool isFirst = true; | |||
| foreach (var entry in node.DictValue!) | |||
| { | |||
| if (!isFirst) | |||
| { | |||
| sb.Append(", "); | |||
| } | |||
| sb.Append(entry.Key).Append(": "); | |||
| WriteString(entry.Value, sb); | |||
| isFirst = false; | |||
| } | |||
| sb.Append("}"); | |||
| break; | |||
| } | |||
| default: | |||
| sb.Append("<empty>"); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,99 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// A dictionary that can also be used as a nest structure. All dictionary members forward to | |||
| /// the wrapped <see cref="Value"/>. The nest-related members only look at the values: both | |||
| /// <see cref="MapStructure{TOut}"/> and <see cref="AsNest"/> produce list-shaped results without keys. | |||
| /// </summary> | |||
| public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull | |||
| { | |||
| /// <summary> | |||
| /// The wrapped dictionary. | |||
| /// </summary> | |||
| public IDictionary<TKey, TValue> Value { get; set; } | |||
| /// <summary> | |||
| /// Wraps <paramref name="dict"/> without copying it. | |||
| /// </summary> | |||
| public NestDictionary(IDictionary<TKey, TValue> dict) | |||
| { | |||
| Value = dict; | |||
| } | |||
| /// <summary> | |||
| /// Returns the values of the wrapped dictionary. | |||
| /// </summary> | |||
| public IEnumerable<TValue> Flatten() | |||
| => Value.Select(entry => entry.Value); | |||
| /// <summary> | |||
| /// Maps every value with <paramref name="func"/> and returns the results as a <see cref="NestList{T}"/>. | |||
| /// </summary> | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func) | |||
| => new NestList<TOut>(Value.Select(entry => func(entry.Value))); | |||
| /// <summary> | |||
| /// Returns a list nest with one node per value. | |||
| /// </summary> | |||
| public Nest<TValue> AsNest() | |||
| => new Nest<TValue>(Value.Values.Select(item => new Nest<TValue>(item))); | |||
| // Required IDictionary<TKey, TValue> members | |||
| public int Count => Value.Count; | |||
| public bool IsReadOnly => Value.IsReadOnly; | |||
| public ICollection<TKey> Keys => Value.Keys; | |||
| public ICollection<TValue> Values => Value.Values; | |||
| public void Add(TKey key, TValue value) => Value.Add(key, value); | |||
| public void Add(KeyValuePair<TKey, TValue> item) => Value.Add(item); | |||
| public void Clear() => Value.Clear(); | |||
| public bool Contains(KeyValuePair<TKey, TValue> item) => Value.Contains(item); | |||
| public bool ContainsKey(TKey key) => Value.ContainsKey(key); | |||
| public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) => Value.CopyTo(array, arrayIndex); | |||
| public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() => Value.GetEnumerator(); | |||
| IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); | |||
| public bool Remove(TKey key) => Value.Remove(key); | |||
| public bool Remove(KeyValuePair<TKey, TValue> item) => Value.Remove(item); | |||
| public bool TryGetValue(TKey key, out TValue value) => Value.TryGetValue(key, out value); | |||
| // Optional IDictionary<TKey, TValue> members | |||
| public TValue this[TKey key] | |||
| { | |||
| get | |||
| { | |||
| return Value[key]; | |||
| } | |||
| set | |||
| { | |||
| Value[key] = value; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,43 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// A list that supports the nest structure interface, with a nesting depth of 1. | |||
| /// </summary> | |||
| /// <typeparam name="T">The type of the stored values.</typeparam> | |||
| public sealed class NestList<T> : INestStructure<T>, IEnumerable<T> | |||
| { | |||
| /// <summary> | |||
| /// The stored values. | |||
| /// </summary> | |||
| public List<T> Value { get; set; } | |||
| /// <summary> | |||
| /// Copies <paramref name="values"/> into a new list. | |||
| /// </summary> | |||
| public NestList(IEnumerable<T> values) | |||
| { | |||
| Value = new List<T>(values); | |||
| } | |||
| /// <summary> | |||
| /// Returns the stored list itself. | |||
| /// </summary> | |||
| public IEnumerable<T> Flatten() => Value; | |||
| /// <summary> | |||
| /// Maps every value with <paramref name="func"/> into a new <see cref="NestList{T}"/>. | |||
| /// </summary> | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
| => new NestList<TOut>(Value.Select(item => func(item))); | |||
| /// <summary> | |||
| /// Returns a list nest with one node per value. | |||
| /// </summary> | |||
| public Nest<T> AsNest() | |||
| => new Nest<T>(Value.Select(item => new Nest<T>(item))); | |||
| // Enumerator implementation | |||
| public IEnumerator<T> GetEnumerator() => Value.GetEnumerator(); | |||
| IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,32 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| /// <summary> | |||
| /// A nested structure that holds exactly one element. | |||
| /// </summary> | |||
| /// <typeparam name="T">The type of the stored value.</typeparam> | |||
| public class NestNode<T> : INestStructure<T> | |||
| { | |||
| /// <summary> | |||
| /// The stored value. | |||
| /// </summary> | |||
| public T Value { get; set; } | |||
| /// <summary> | |||
| /// Creates a node holding <paramref name="value"/>. | |||
| /// </summary> | |||
| public NestNode(T value) | |||
| { | |||
| Value = value; | |||
| } | |||
| /// <summary> | |||
| /// Lazily yields the single stored value. | |||
| /// </summary> | |||
| public IEnumerable<T> Flatten() | |||
| { | |||
| T only = Value; | |||
| yield return only; | |||
| } | |||
| /// <summary> | |||
| /// Applies <paramref name="func"/> to the value and wraps the result in a new node. | |||
| /// </summary> | |||
| public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
| => new NestNode<TOut>(func(Value)); | |||
| /// <summary> | |||
| /// Returns a node nest holding the value. | |||
| /// </summary> | |||
| public Nest<T> AsNest() => new Nest<T>(Value); | |||
| } | |||
| } | |||
| @@ -3,7 +3,7 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| namespace Tensorflow.Keras.Saving | |||
| namespace Tensorflow.Common.Types | |||
| { | |||
| public class TensorShapeConfig | |||
| { | |||
| @@ -1,17 +1,15 @@ | |||
| using Newtonsoft.Json; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras.Layers.Rnn; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| // TODO(Rinne): add regularizers. | |||
| public class RNNArgs : AutoSerializeLayerArgs | |||
| { | |||
| public interface IRnnArgCell : ILayer | |||
| { | |||
| object state_size { get; } | |||
| } | |||
| [JsonProperty("cell")] | |||
| // TODO: the cell should be serialized with `serialize_keras_object`. | |||
| public IRnnArgCell Cell { get; set; } = null; | |||
| public IRnnCell Cell { get; set; } = null; | |||
| [JsonProperty("return_sequences")] | |||
| public bool ReturnSequences { get; set; } = false; | |||
| [JsonProperty("return_state")] | |||
| @@ -34,6 +32,9 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| public IInitializer KernelInitializer { get; set; } | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| public IInitializer BiasInitializer { get; set; } | |||
| public float Dropout { get; set; } = .0f; | |||
| public bool ZeroOutputForMask { get; set; } = false; | |||
| public float RecurrentDropout { get; set; } = .0f; | |||
| // kernel_regularizer=None, | |||
| // recurrent_regularizer=None, | |||
| @@ -0,0 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| /// <summary> | |||
| /// The optional-args container identified as "Rnn". | |||
| /// NOTE(review): presumably mirrors the `mask` and `constants` arguments of the Keras RNN call — confirm. | |||
| /// </summary> | |||
| public class RnnOptionalArgs: IOptionalArgs | |||
| { | |||
| /// <summary> | |||
| /// Identifies this kind of optional args; always "Rnn". | |||
| /// </summary> | |||
| public string Identifier => "Rnn"; | |||
| /// <summary> | |||
| /// Optional mask tensor; null by default. | |||
| /// </summary> | |||
| public Tensor Mask { get; set; } = null; | |||
| /// <summary> | |||
| /// Optional constant tensors; null by default. | |||
| /// </summary> | |||
| public Tensors Constants { get; set; } = null; | |||
| } | |||
| } | |||
| @@ -0,0 +1,29 @@ | |||
| using Newtonsoft.Json; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
| { | |||
| /// <summary> | |||
| /// Arguments of a simple RNN cell. Each property carries the JSON name given in its | |||
| /// <c>JsonProperty</c> attribute. | |||
| /// </summary> | |||
| public class SimpleRNNCellArgs: AutoSerializeLayerArgs | |||
| { | |||
| // JSON name "units". | |||
| [JsonProperty("units")] | |||
| public int Units { get; set; } | |||
| // TODO(Rinne): lack of initialized value of Activation. Merging keras | |||
| // into tf.net could resolve it. | |||
| [JsonProperty("activation")] | |||
| public Activation Activation { get; set; } | |||
| // Defaults to true. | |||
| [JsonProperty("use_bias")] | |||
| public bool UseBias { get; set; } = true; | |||
| // Defaults to 0. | |||
| [JsonProperty("dropout")] | |||
| public float Dropout { get; set; } = .0f; | |||
| // Defaults to 0. | |||
| [JsonProperty("recurrent_dropout")] | |||
| public float RecurrentDropout { get; set; } = .0f; | |||
| // The three initializers below have no default assigned here. | |||
| [JsonProperty("kernel_initializer")] | |||
| public IInitializer KernelInitializer { get; set; } | |||
| [JsonProperty("recurrent_initializer")] | |||
| public IInitializer RecurrentInitializer { get; set; } | |||
| [JsonProperty("bias_initializer")] | |||
| public IInitializer BiasInitializer { get; set; } | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Training; | |||
| @@ -14,7 +15,7 @@ namespace Tensorflow.Keras | |||
| List<ILayer> Layers { get; } | |||
| List<INode> InboundNodes { get; } | |||
| List<INode> OutboundNodes { get; } | |||
| Tensors Apply(Tensors inputs, Tensor state = null, bool training = false); | |||
| Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null); | |||
| List<IVariableV1> TrainableVariables { get; } | |||
| List<IVariableV1> TrainableWeights { get; } | |||
| List<IVariableV1> NonTrainableWeights { get; } | |||
| @@ -0,0 +1,19 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| /// <summary> | |||
| /// A layer that additionally reports its state size and output size. | |||
| /// </summary> | |||
| public interface IRnnCell: ILayer | |||
| { | |||
| /// <summary> | |||
| /// The size of the cell state, as a generalized tensor shape. | |||
| /// </summary> | |||
| GeneralizedTensorShape StateSize { get; } | |||
| /// <summary> | |||
| /// The size of the cell output, as a generalized tensor shape. | |||
| /// </summary> | |||
| GeneralizedTensorShape OutputSize { get; } | |||
| // NOTE(review): presumably flags cells coming from the tf (non-keras) RNN API — confirm. | |||
| bool IsTFRnnCell { get; } | |||
| /// <summary> | |||
| /// Whether the optional RNN args are supported when applying the layer. | |||
| /// In other words, whether `Apply` is overridden to process `RnnOptionalArgs`. | |||
| /// </summary> | |||
| bool SupportOptionalArgs { get; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| /// <summary> | |||
| /// An <see cref="IRnnCell"/> that also exposes a count and an integer indexer returning <see cref="IRnnCell"/> instances. | |||
| /// </summary> | |||
| public interface IStackedRnnCells : IRnnCell | |||
| { | |||
| // NOTE(review): presumably the number of cells in the stack - confirm against implementations. | |||
| int Count { get; } | |||
| // NOTE(review): presumably returns the cell at position `idx` in the stack - confirm against implementations. | |||
| IRnnCell this[int idx] { get; } | |||
| } | |||
| } | |||
| @@ -3,6 +3,7 @@ using Newtonsoft.Json; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Saving.Json | |||
| { | |||
| @@ -6,6 +6,7 @@ using System.Text; | |||
| using System.Diagnostics; | |||
| using OneOf.Types; | |||
| using Tensorflow.Keras.Saving.Json; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Saving | |||
| { | |||
| @@ -74,8 +74,3 @@ namespace Tensorflow | |||
| => IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})"; | |||
| } | |||
| } | |||
| namespace System.Runtime.CompilerServices | |||
| { | |||
| internal static class IsExternalInit { } | |||
| } | |||
| @@ -53,7 +53,7 @@ public class Orthogonal : IInitializer | |||
| // Compute the qr factorization | |||
| var (q, r) = tf.linalg.qr(a, full_matrices: false); | |||
| // Make Q uniform | |||
| var d = tf.linalg.tensor_diag_part(r); | |||
| var d = tf.linalg.tensor_diag_part(r.Single); | |||
| q *= tf.sign(d); | |||
| if (num_rows < num_cols) | |||
| @@ -11,6 +11,7 @@ namespace Tensorflow | |||
| /// Basic LSTM recurrent network cell. | |||
| /// The implementation is based on: http://arxiv.org/abs/1409.2329. | |||
| /// </summary> | |||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
| public class BasicLstmCell : LayerRnnCell | |||
| { | |||
| int _num_units; | |||
| @@ -20,6 +20,7 @@ using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
| public class BasicRnnCell : LayerRnnCell | |||
| { | |||
| int _num_units; | |||
| @@ -19,6 +19,7 @@ using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
| public class LayerRnnCell : RnnCell | |||
| { | |||
| protected InputSpec inputSpec; | |||
| @@ -16,10 +16,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers.Rnn; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Operations; | |||
| @@ -50,7 +52,8 @@ namespace Tensorflow | |||
| /// matching structure of Tensors having shape `[batch_size].concatenate(s)` | |||
| /// for each `s` in `self.batch_size`. | |||
| /// </summary> | |||
| public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell | |||
| [Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")] | |||
| public abstract class RnnCell : ILayer, IRnnCell | |||
| { | |||
| /// <summary> | |||
| /// Attribute that indicates whether the cell is a TF RNN cell, due the slight | |||
| @@ -142,7 +145,7 @@ namespace Tensorflow | |||
| throw new NotImplementedException("_zero_state_tensors"); | |||
| } | |||
| public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| @@ -173,5 +176,14 @@ namespace Tensorflow | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| /// <summary> | |||
| /// Not implemented for this cell: every invocation throws <see cref="NotImplementedException"/>. | |||
| /// </summary> | |||
| public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) | |||
| => throw new NotImplementedException(); | |||
| // The IRnnCell members below are placeholders on this class: reading any of them throws NotImplementedException. | |||
| public GeneralizedTensorShape StateSize => throw new NotImplementedException(); | |||
| public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); | |||
| public bool IsTFRnnCell => throw new NotImplementedException(); | |||
| public bool SupportOptionalArgs => throw new NotImplementedException(); | |||
| } | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Framework; | |||
| using static Tensorflow.Binding; | |||
| @@ -48,6 +49,7 @@ namespace Tensorflow.Operations | |||
| public override Tensor flow => _flow; | |||
| bool _clear_after_read; | |||
| List<Tensor> _tensor_array; | |||
| List<int> _previous_read_indices; | |||
| public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false, | |||
| bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
| @@ -61,16 +63,20 @@ namespace Tensorflow.Operations | |||
| _dtype = dtype.as_base_dtype(); | |||
| _dynamic_size = dynamic_size; | |||
| _clear_after_read = clear_after_read; | |||
| _tensor_array = new List<Tensor>(); | |||
| _tensor_array = Enumerable.Repeat<Tensor>(null, size.numpy()).ToList(); | |||
| _previous_read_indices = new(); | |||
| } | |||
| public override TensorArray unstack(Tensor value, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate | |||
| var tensors = array_ops.unstack(value, name: name); | |||
| if(tensors.Length > _tensor_array.Count && !_dynamic_size) | |||
| { | |||
| var num_elements = array_ops.shape(value)[0]; | |||
| return scatter(indices: math_ops.range(0, num_elements), value: value, name: name); | |||
| }); | |||
| throw new ValueError($"Cannot unstack {tensors.Length} tensors into a TensorArray of static size {_tensor_array.Count}"); | |||
| } | |||
| _tensor_array = tensors.ToList(); | |||
| // TODO(Rinne): revise the implementation. Here we should return `parent()`. | |||
| return this; | |||
| } | |||
| public TensorArray scatter(Tensor indices, Tensor value, string name = null) | |||
| @@ -116,9 +122,19 @@ namespace Tensorflow.Operations | |||
| _colocate_with.Add(value); | |||
| } | |||
| /// <summary> | |||
| /// Returns the tensor stored at slot <paramref name="ix"/>. When the slot is still null it is | |||
| /// first filled with zeros built from `_element_shape` and `_dtype`, and that tensor is returned. | |||
| /// </summary> | |||
| private Tensor _maybe_zero(int ix) | |||
| { | |||
| var existing = _tensor_array[ix]; | |||
| if (existing is not null) | |||
| { | |||
| return existing; | |||
| } | |||
| // Empty slot: materialize the zero tensor and remember it so later calls return the same instance. | |||
| var zeros = array_ops.zeros(_element_shape, _dtype); | |||
| _tensor_array[ix] = zeros; | |||
| return zeros; | |||
| } | |||
| public override Tensor read<T>(T index, string name = null) | |||
| { | |||
| int index_int = -1; | |||
| int index_int; | |||
| if (index is int int_index) | |||
| index_int = int_index; | |||
| else if (index is Tensor tensor_index) | |||
| @@ -126,27 +142,75 @@ namespace Tensorflow.Operations | |||
| else | |||
| throw new ValueError(""); | |||
| if(index_int >= _tensor_array.Count) | |||
| { | |||
| throw new OutOfRangeError($"Tried to read from index {index_int} but array size is: {_tensor_array.Count} "); | |||
| } | |||
| var res = _tensor_array[index_int]; | |||
| if(res is null) | |||
| { | |||
| if (_previous_read_indices.Contains(index_int)) | |||
| { | |||
| throw new InvalidArgumentError($"Could not read index {index_int} twice because it was cleared after " + | |||
| $"a previous read (perhaps try setting clear_after_read = false?)"); | |||
| } | |||
| else | |||
| { | |||
| res = _maybe_zero(index_int); | |||
| } | |||
| } | |||
| if (_clear_after_read) | |||
| { | |||
| _tensor_array[index_int] = null; | |||
| _previous_read_indices.Add(index_int); | |||
| } | |||
| return _tensor_array[index_int]; | |||
| return res; | |||
| } | |||
| public override TensorArray write(Tensor index, Tensor value, string name = null) | |||
| { | |||
| if (_infer_shape) | |||
| _element_shape = _element_shape.merge_with(value.shape); | |||
| _tensor_array.add(value); | |||
| return this; | |||
| int index_int; | |||
| if(index is EagerTensor eager) | |||
| { | |||
| return write<Tensor>(eager.numpy(), value, name); | |||
| } | |||
| throw new InvalidArgumentError("The index is supposed to be an EagerTensor"); | |||
| } | |||
| public override TensorArray write<T>(int index, T value, string name = null) | |||
| { | |||
| var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
| var index_tensor = ops.convert_to_tensor(index, name: "index"); | |||
| return write(index_tensor, value_tensor, name: name); | |||
| int size = _tensor_array.Count; | |||
| if(index >= size) | |||
| { | |||
| if (!_dynamic_size) | |||
| { | |||
| throw new OutOfRangeError($"Tried to write to index {index} but array is not resizeable and size " + | |||
| $"is: {size} "); | |||
| } | |||
| _tensor_array.AddRange(Enumerable.Repeat<Tensor>(null, index - size + 1)); | |||
| } | |||
| Tensor tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
| if(_dtype != tensor.dtype) | |||
| { | |||
| throw new InvalidArgumentError($"TensorArray dtype is {_dtype.as_python_name()} but Op is " + | |||
| $"trying to write dtype {tensor.dtype.as_python_name()} "); | |||
| } | |||
| if (!_element_shape.is_compatible_with(tensor.shape)) | |||
| { | |||
| throw new ValueError($"Incompatible shape for value ({tensor.shape}), expected ({_element_shape})"); | |||
| } | |||
| if (_infer_shape) | |||
| { | |||
| _element_shape = _element_shape.merge_with(tensor.shape); | |||
| } | |||
| _tensor_array[index] = tensor; | |||
| return this; | |||
| } | |||
| private Tensor size(string name = null) | |||
| @@ -156,11 +220,26 @@ namespace Tensorflow.Operations | |||
| public override Tensor stack(string name = null) | |||
| { | |||
| ops.colocate_with(_handle); | |||
| return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate | |||
| if(_tensor_array.Count > 0) | |||
| { | |||
| return gather(math_ops.range(0, size()), name: name); | |||
| }); | |||
| for(int i = 0; i < _tensor_array.Count; i++) | |||
| { | |||
| _maybe_zero(i); | |||
| } | |||
| } | |||
| if(_tensor_array.Count == 0 && _element_shape.IsFullyDefined) | |||
| { | |||
| return ops.convert_to_tensor(new Shape(new long[] { 0 }.Concat(_element_shape.dims).ToArray()), name: name, dtype: _dtype); | |||
| } | |||
| else | |||
| { | |||
| return ops.convert_to_tensor(_tensor_array, name: name, dtype: _dtype); | |||
| } | |||
| //ops.colocate_with(_handle); | |||
| //return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate | |||
| //{ | |||
| // return gather(math_ops.range(0, size()), name: name); | |||
| //}); | |||
| } | |||
| public override Tensor gather(Tensor indices, string name = null) | |||
| @@ -30,7 +30,7 @@ namespace Tensorflow | |||
| name: name); | |||
| return tf.Context.ExecuteOp("PrintV2", name, new ExecuteOpArgs(formatted_string) | |||
| .SetAttributes(new { output_stream, end })); | |||
| .SetAttributes(new { output_stream, end })).SingleOrNull; | |||
| } | |||
| } | |||
| } | |||
| @@ -44,7 +44,7 @@ namespace Tensorflow | |||
| { | |||
| sorted = true | |||
| })); | |||
| return indices; | |||
| return indices.Single; | |||
| } | |||
| public static Tensor sort(Tensor values, Axis axis, string direction = "ASCENDING", string? name = null) | |||
| @@ -114,4 +114,9 @@ https://tensorflownet.readthedocs.io</Description> | |||
| <PackageReference Include="Protobuf.Text" Version="0.7.0" /> | |||
| <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | |||
| </ItemGroup> | |||
| <ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'"> | |||
| <PackageReference Include="IsExternalInit" Version="1.0.3" PrivateAssets="all" /> | |||
| <PackageReference Include="System.Memory" Version="4.5.4" PrivateAssets="all" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -3,6 +3,7 @@ using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -13,157 +14,231 @@ namespace Tensorflow | |||
| /// and Tensor[] from Tensors implicitily. | |||
| /// It works for tuple and scalar as well. | |||
| /// </summary> | |||
| public class Tensors : IEnumerable<Tensor>, IDisposable | |||
| public sealed class Tensors : Nest<Tensor>, IDisposable | |||
| { | |||
| List<Tensor> items = new List<Tensor>(); | |||
| public TF_DataType dtype => items.First().dtype; | |||
| public Shape shape => items.First().shape; | |||
| public int rank => items.First().rank; | |||
| public Graph graph => items.First().graph; | |||
| public TF_DataType dtype => this.First().dtype; | |||
| public Shape shape => this.First().shape; | |||
| public int rank => this.First().rank; | |||
| public Graph graph => this.First().graph; | |||
| public bool IsList { get; set; } | |||
| public int Length => items.Count(); | |||
| public int Length => this.Count(); | |||
| /// <summary> | |||
| /// Return the only tensor held by this `Tensors`. | |||
| /// Throws a `ValueError` when the instance holds no tensor or more than one tensor. | |||
| /// </summary> | |||
| public Tensor Single | |||
| { | |||
| get | |||
| { | |||
| // `Length` is computed with `Count()` on every access, so read it once. | |||
| var count = Length; | |||
| if (count != 1) | |||
| { | |||
| // Report the actual count: the check also rejects an empty `Tensors`, | |||
| // which the old "more than one tensor" wording misreported. | |||
| throw new ValueError($"Tensors with {count} tensors cannot be " + | |||
| "converted to a single Tensor; exactly one tensor is required."); | |||
| } | |||
| return this.First(); | |||
| } | |||
| } | |||
| public Tensor this[int index] | |||
| /// <summary> | |||
| /// Return a Tensor if `Tensors` has only one tensor, and return null when `Tensors` is empty, | |||
| /// otherwise throw an exception. | |||
| /// </summary> | |||
| public Tensor? SingleOrNull | |||
| { | |||
| get => items[index]; | |||
| set => items[index] = value; | |||
| get | |||
| { | |||
| if (Length > 1) | |||
| { | |||
| throw new ValueError($"Tensors with {Length} tensor cannot be " + | |||
| "implicitly converted to Tensor."); | |||
| } | |||
| return this.FirstOrDefault(); | |||
| } | |||
| } | |||
| public Tensor this[params string[] slices] | |||
| => items.First()[slices]; | |||
| public Tensors(params Tensor[] tensors) | |||
| => this.First()[slices]; | |||
| public Tensors(Tensor tensor) : base(tensor) | |||
| { | |||
| } | |||
| private Tensors(Nest<Tensor> nested) : base(nested) | |||
| { | |||
| } | |||
| public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest<Tensor>(x))) | |||
| { | |||
| } | |||
| public Tensors(IEnumerable<Tensor> tensors): base(tensors.Select(x => new Nest<Tensor>(x))) | |||
| { | |||
| items.AddRange(tensors); | |||
| } | |||
| public Tensors(IEnumerable<Tensor> tensors) | |||
| public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) | |||
| { | |||
| items.AddRange(tensors); | |||
| } | |||
| public Tensors(NDArray nd) | |||
| public bool IsSingle() | |||
| { | |||
| items.Add(ops.convert_to_tensor(nd)); | |||
| return Length == 1; | |||
| } | |||
| public IEnumerator<Tensor> GetEnumerator() | |||
| public new Tensors MergeWith(Nest<Tensor>? other) | |||
| { | |||
| foreach (var tensor in items) | |||
| yield return tensor; | |||
| return FromNest(base.MergeWith(other)); | |||
| } | |||
| [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + | |||
| "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] | |||
| public void Add(Tensor tensor) | |||
| => items.Add(tensor); | |||
| { | |||
| if(NestType == NestType.Dictionary) | |||
| { | |||
| throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||
| } | |||
| else if(NestType == NestType.Node) | |||
| { | |||
| NestType = NestType.List; | |||
| ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) }; | |||
| Value = null; | |||
| } | |||
| else | |||
| { | |||
| ListValue.Add(new Nest<Tensor>(tensor)); | |||
| } | |||
| } | |||
| [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + | |||
| "some tensors to `Tensors`, creating a new instance with your newly added tensors is a better choice.")] | |||
| public void AddRange(IEnumerable<Tensor> tensors) | |||
| => items.AddRange(tensors); | |||
| { | |||
| if (NestType == NestType.Dictionary) | |||
| { | |||
| throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||
| } | |||
| else if (NestType == NestType.Node) | |||
| { | |||
| NestType = NestType.List; | |||
| ListValue = new() { new Nest<Tensor>(Value) }; | |||
| ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||
| Value = null; | |||
| } | |||
| else | |||
| { | |||
| ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||
| } | |||
| } | |||
| [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to insert " + | |||
| "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] | |||
| public void Insert(int index, Tensor tensor) | |||
| => items.Insert(index, tensor); | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| => GetEnumerator(); | |||
| { | |||
| if (NestType == NestType.List) | |||
| { | |||
| ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||
| } | |||
| else if(NestType == NestType.Node) | |||
| { | |||
| NestType = NestType.List; | |||
| ListValue = new() { new Nest<Tensor>(Value) }; | |||
| ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||
| Value = null; | |||
| } | |||
| else | |||
| { | |||
| throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||
| } | |||
| } | |||
| public string[] StringData() | |||
| { | |||
| EnsureSingleTensor(this, "nnumpy"); | |||
| return this[0].StringData(); | |||
| return Single.StringData(); | |||
| } | |||
| public string StringData(int index) | |||
| { | |||
| EnsureSingleTensor(this, "nnumpy"); | |||
| return this[0].StringData(index); | |||
| return Single.StringData(index); | |||
| } | |||
| public NDArray numpy() | |||
| { | |||
| EnsureSingleTensor(this, "nnumpy"); | |||
| return this[0].numpy(); | |||
| return Single.numpy(); | |||
| } | |||
| [Obsolete] | |||
| public T[] ToArray<T>() where T: unmanaged | |||
| { | |||
| EnsureSingleTensor(this, $"ToArray<{typeof(T)}>"); | |||
| return this[0].ToArray<T>(); | |||
| return Single.ToArray<T>(); | |||
| } | |||
| #region Explicit Conversions | |||
| public unsafe static explicit operator bool(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to bool"); | |||
| return (bool)tensor[0]; | |||
| return (bool)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator sbyte(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to sbyte"); | |||
| return (sbyte)tensor[0]; | |||
| return (sbyte)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator byte(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||
| return (byte)tensor[0]; | |||
| return (byte)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator ushort(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to ushort"); | |||
| return (ushort)tensor[0]; | |||
| return (ushort)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator short(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to short"); | |||
| return (short)tensor[0]; | |||
| return (short)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator int(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to int"); | |||
| return (int)tensor[0]; | |||
| return (int)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator uint(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to uint"); | |||
| return (uint)tensor[0]; | |||
| return (uint)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator long(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to long"); | |||
| return (long)tensor[0]; | |||
| return (long)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator ulong(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to ulong"); | |||
| return (ulong)tensor[0]; | |||
| return (ulong)tensor.Single; | |||
| } | |||
| /// <summary> | |||
| /// Explicitly converts a `Tensors` holding exactly one tensor to `float`. | |||
| /// </summary> | |||
| public unsafe static explicit operator float(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||
| return (byte)tensor[0]; | |||
| // Convert with the tensor's float conversion; the earlier `(byte)` cast routed the value through `byte` first. | |||
| return (float)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator double(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to double"); | |||
| return (double)tensor[0]; | |||
| return (double)tensor.Single; | |||
| } | |||
| public unsafe static explicit operator string(Tensors tensor) | |||
| { | |||
| EnsureSingleTensor(tensor, "explicit conversion to string"); | |||
| return (string)tensor[0]; | |||
| return (string)tensor.Single; | |||
| } | |||
| public static explicit operator object[](Tensors tensors) | |||
| => tensors.items.ToArray(); | |||
| => tensors.Flatten().ToArray(); | |||
| #endregion | |||
| #region Implicit Conversions | |||
| @@ -183,56 +258,44 @@ namespace Tensorflow | |||
| public static implicit operator Tensors(List<Tensor> tensors) | |||
| => new Tensors(tensors.ToArray()); | |||
| public static implicit operator Tensor(Tensors tensors) | |||
| => tensors.FirstOrDefault(); | |||
| public static implicit operator Tensor(Tensors? tensors) | |||
| => tensors?.SingleOrNull; | |||
| public static implicit operator Tensor[](Tensors tensors) | |||
| => tensors.items.ToArray(); | |||
| => tensors.Flatten().ToArray(); | |||
| #endregion | |||
| public void Deconstruct(out Tensor a, out Tensor b) | |||
| public static Tensors? FromNest(Nest<Tensor> nested) | |||
| { | |||
| a = items[0]; | |||
| b = items[1]; | |||
| if(nested == Nest<Tensor>.Empty) | |||
| { | |||
| return null; | |||
| } | |||
| return new Tensors(nested); | |||
| } | |||
| private static void EnsureSingleTensor(Tensors tensors, string methodnName) | |||
| public void Deconstruct(out Tensor a, out Tensors? b) | |||
| { | |||
| if(tensors.Length == 0) | |||
| { | |||
| throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor."); | |||
| } | |||
| else if(tensors.Length > 1) | |||
| { | |||
| throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor."); | |||
| } | |||
| a = this.First(); | |||
| b = Length == 1? null : new Tensors(this.Skip(1)); | |||
| } | |||
| public override string ToString() | |||
| { | |||
| if(items.Count == 1) | |||
| if(Length == 1) | |||
| { | |||
| return items[0].ToString(); | |||
| return this.First().ToString(); | |||
| } | |||
| else | |||
| { | |||
| StringBuilder sb = new StringBuilder(); | |||
| sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n"); | |||
| for(int i = 0; i < items.Count; i++) | |||
| { | |||
| var tensor = items[i]; | |||
| sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n"); | |||
| } | |||
| sb.Append("]\n"); | |||
| return sb.ToString(); | |||
| return $"Totally {Length} tensors: {base.ToString()}"; | |||
| } | |||
| } | |||
| public void Dispose() | |||
| { | |||
| foreach (var item in items) | |||
| item.Dispose(); | |||
| foreach (var tensor in this) | |||
| tensor.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -36,6 +36,7 @@ namespace Tensorflow.Util | |||
| // (np.array([3, 4]), tf.constant([3, 4])))` | |||
| // | |||
| [Obsolete] | |||
| public static class nest | |||
| { | |||
| @@ -20,8 +20,11 @@ using System.Linq; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Common.Extensions; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.Graphs.SubGraphUtility; | |||
| using Tensorflow.Util; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras | |||
| { | |||
| @@ -450,5 +453,535 @@ namespace Tensorflow.Keras | |||
| return x; | |||
| } | |||
| public (Tensors, Tensors, Tensors) rnn( | |||
| Func<Tensors, Tensors, (Tensors, Tensors)> step_function, // args:inputs, states, return:output, new_states | |||
| Tensors inputs, // inputs is a tuple of tensors (one per input sequence) | |||
| Tensors initial_states, | |||
| bool go_backwards = false, | |||
| Tensor? mask = null, | |||
| Tensors? constants = null, | |||
| bool unroll = false, | |||
| Tensors? input_length = null, // An integer or a 1-D Tensor,depending on whether the time dimension is fixed-length or not | |||
| bool time_major = false, | |||
| bool zero_output_for_mask = false, | |||
| bool return_all_outputs = true) | |||
| { | |||
| // Transposes `input_t` so that its first two axes trade places; all other axes keep their position. | |||
| Tensor swap_batch_timestep(Tensor input_t) | |||
| { | |||
| // Start from the identity permutation [0, 1, 2, ...] and exchange the two leading entries. | |||
| var perm = Enumerable.Range(0, input_t.rank).ToArray(); | |||
| (perm[0], perm[1]) = (perm[1], perm[0]); | |||
| return tf.transpose(input_t, perm); | |||
| } | |||
| if (!time_major) | |||
| { | |||
| inputs = Nest.MapStructure(swap_batch_timestep, inputs).ToTensors(); | |||
| } | |||
| var flatted_inptus = Nest.Flatten(inputs).ToList(); | |||
| var first_flatted_input = flatted_inptus[0]; | |||
| var time_steps = first_flatted_input.shape[0]; | |||
| var batch = first_flatted_input.shape[1]; | |||
| var time_steps_t = (int)first_flatted_input.shape[0]; | |||
| foreach (var input_ in flatted_inptus) | |||
| { | |||
| input_.shape.with_rank_at_least(3); | |||
| } | |||
| if (mask != null) | |||
| { | |||
| if (mask.dtype != TF_DataType.TF_BOOL) | |||
| { | |||
| mask = tf.cast(mask, TF_DataType.TF_BOOL); | |||
| } | |||
| if (mask.rank == 2) | |||
| { | |||
| mask = tf.expand_dims(mask, -1); | |||
| } | |||
| if (!time_major) | |||
| { | |||
| mask = swap_batch_timestep(mask); | |||
| } | |||
| } | |||
| // tf.where needs its condition tensor to be the same shape as its two | |||
| // result tensors, but in our case the condition (mask) tensor is | |||
| // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. | |||
| // So we need to broadcast the mask to match the shape of inputs. | |||
| // That's what the tile call does, it just repeats the mask along its | |||
| // second dimension n times. | |||
| // Broadcasts `mask_t` to the shape of `input_t`: trailing axes are appended until the ranks match, | |||
| // then the mask is tiled along every axis from `fixed_dim` onwards. | |||
| // Throws ValueError when either argument holds anything other than exactly one tensor. | |||
| Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1) | |||
| { | |||
| if (!mask_t.IsSingle()) | |||
| { | |||
| throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}"); | |||
| } | |||
| if (!input_t.IsSingle()) | |||
| { | |||
| throw new ValueError($"input_t is expected to be tensor, but got {input_t}"); | |||
| } | |||
| // Append size-1 axes at the end until the mask has as many axes as the input. | |||
| var rank_diff = input_t.rank - mask_t.rank; | |||
| for (int i = 0; i < rank_diff; i++) | |||
| { | |||
| mask_t = tf.expand_dims(mask_t, -1); | |||
| } | |||
| // The first `fixed_dim` axes get a multiple of 1; the remaining axes are tiled by the input's dims. | |||
| // List<T>.GetRange takes (index, count), so the count must be the number of remaining axes | |||
| // (rank - fixed_dim); passing the full rank always ran past the end of the list. | |||
| var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank - fixed_dim)); | |||
| return tf.tile(mask_t, multiples); | |||
| } | |||
| Tensors outputs = new Tensors(); | |||
| Tensors output_time_zero = new Tensors(); | |||
| Tensors last_output = new Tensors(); | |||
| Tensors new_states = new Tensors(); | |||
| if (unroll) | |||
| { | |||
| if (time_steps == 0) | |||
| { | |||
| throw new ValueError("Unrolling requires a fixed number of timesteps."); | |||
| } | |||
| // Process the input tensors. The input tensor need to be split on the | |||
| // time_step dim, and reverse if go_backwards is True. In the case of | |||
| // nested input, the input is flattened and then transformed | |||
| // individually. The result of this will be a tuple of lists, each of | |||
| // the item in tuple is list of the tensor with shape (batch, feature) | |||
| // TODO(Wanglongzhi2001): step_func accepts a List as its second argument, but a tuple is what ends up being used here. | |||
| //var states = Tuple.Create(initial_states); | |||
| var states = initial_states; | |||
| var successive_states = new Tensors(); | |||
| var successive_outputs = new Tensors(); | |||
| // Process the input tensors. The input tensor need to be split on the | |||
| // time_step dim, and reverse if go_backwards is True. In the case of | |||
| // nested input, the input is flattened and then transformed | |||
| // individually. The result of this will be a tuple of lists, each of | |||
| // the item in tuple is list of the tensor with shape (batch, feature) | |||
| // Unstacks one input tensor into its per-step pieces, in reverse order when `go_backwards` is set. | |||
| Tensors _process_single_input_t(Tensor input_t) | |||
| { | |||
| // unstack for time_step dim | |||
| var per_step = array_ops.unstack(input_t); | |||
| if (!go_backwards) | |||
| { | |||
| return per_step; | |||
| } | |||
| return per_step.Reverse().ToArray(); | |||
| } | |||
| // TODO(Wanglongzhi2001) | |||
| Tensors processed_input; | |||
| if (!inputs.IsSingle()) | |||
| { | |||
| processed_input = inputs.MapStructure(_process_single_input_t).ReduceTo<Tensors, Tensor>().ToTensors(); | |||
| } | |||
| else | |||
| { | |||
| processed_input = _process_single_input_t(inputs); | |||
| } | |||
| // Indexes every tensor in `processed_input` at `time` and packs the results following the structure of `inputs`. | |||
| object _get_input_tensor(int time) | |||
| { | |||
| List<Tensor> inp = processed_input.Select<Tensor, Tensor>(t_ => t_[time]).ToList(); | |||
| return Nest.PackSequenceAs(inputs, inp); | |||
| } | |||
| if (mask != null) | |||
| { | |||
| var mask_list = tf.unstack(mask); | |||
| if (go_backwards) | |||
| { | |||
| mask_list.Reverse(); | |||
| } | |||
| for (int i = 0; i < time_steps; i++) | |||
| { | |||
| // TODO(Wanglongzhi2001),deal with _get_input_tensor | |||
| var inp = _get_input_tensor(i); | |||
| var mask_t = mask_list[i]; | |||
| // TODO | |||
| var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants)); | |||
| var tiled_mask_t = _expand_mask(mask_t, output); | |||
| Tensors prev_output; | |||
| if (successive_outputs == null) | |||
| { | |||
| prev_output = tf.zeros_like(output); | |||
| } | |||
| else | |||
| { | |||
| prev_output = successive_outputs[successive_outputs.Length - 1]; | |||
| } | |||
| output = tf.where(tiled_mask_t, output, prev_output); | |||
| var flat_states = Nest.Flatten(states).ToList(); | |||
| var flat_new_states = Nest.Flatten(newStates).ToList(); | |||
| var tiledMaskT = flat_states | |||
| .Select(s => _expand_mask(mask_t, s)) | |||
| .ToArray(); | |||
| var tuple = Tuple.Create(tiledMaskT); | |||
| List<Tensor> flat_final_states = new List<Tensor>(); | |||
| foreach (var (m, s, ps) in zip(tiled_mask_t.ToList(), flat_new_states, flat_states)) | |||
| { | |||
| flat_final_states.Add(tf.where(m, s, ps)); | |||
| } | |||
| states = Nest.PackSequenceAs(states, flat_final_states).ToTensors(); | |||
| if (return_all_outputs) | |||
| { | |||
| successive_outputs.Add(output); | |||
| successive_states.Add(states); | |||
| } | |||
| else | |||
| { | |||
| successive_outputs = new Tensors { output }; | |||
| successive_states = new Tensors { states }; | |||
| } | |||
| } | |||
| last_output = successive_outputs[successive_outputs.Length - 1]; | |||
| new_states = successive_states[successive_states.Length - 1]; | |||
| outputs = tf.stack(successive_outputs); | |||
| if (zero_output_for_mask) | |||
| { | |||
| last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); | |||
| outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); | |||
| } | |||
| else // mask is null | |||
| { | |||
| for (int i = 0; i < time_steps; i++) | |||
| { | |||
| var inp = _get_input_tensor(i); | |||
| var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants)); | |||
| states = newStates; | |||
| if (return_all_outputs) | |||
| { | |||
| successive_outputs.Add(output); | |||
| successive_states.Add(newStates); | |||
| } | |||
| else | |||
| { | |||
| successive_outputs = new Tensors { output }; | |||
| successive_states = new Tensors { newStates }; | |||
| } | |||
| } | |||
| last_output = successive_outputs[successive_outputs.Length - 1]; | |||
| new_states = successive_states[successive_states.Length - 1]; | |||
| outputs = tf.stack(successive_outputs); | |||
| } | |||
| } | |||
| } | |||
| else // unroll == false | |||
| { | |||
| var states = initial_states; | |||
| // Create the input tensor arrays. If the inputs are nested tensors, they | |||
| // are flattened first, and one tensor array is created per flattened | |||
| // tensor. | |||
| var input_ta = new List<TensorArray>(); | |||
| for (int i = 0; i < flatted_inptus.Count; i++) | |||
| { | |||
| input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t)); | |||
| } | |||
| foreach(var (ta, input_) in zip(input_ta, flatted_inptus)) | |||
| { | |||
| if (!go_backwards) | |||
| { | |||
| ta.unstack(input_); | |||
| } | |||
| else | |||
| { | |||
| ta.unstack(reverse(input_, 0)); | |||
| } | |||
| } | |||
| // Get the time(0) input and compute its output; that output is used to | |||
| // determine the dtype of the output tensor arrays. Don't read from | |||
| // input_ta, because TensorArray's clear_after_read defaults to True. | |||
| var inps = new Tensors(); | |||
| foreach (var inp in flatted_inptus) | |||
| { | |||
| inps.Add(inp[0]); | |||
| } | |||
| var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors(); | |||
| // output_time_zero is used to determine the cell output shape and its | |||
| // dtype. the value is discarded. | |||
| (output_time_zero, _) = step_function((Tensor)input_time_zero, | |||
| constants is null ? initial_states : initial_states.MergeWith(constants)); | |||
| int output_ta_size = return_all_outputs ? time_steps_t : 1; | |||
| var output_ta = new List<TensorArray>(); | |||
| for (int i = 0; i < output_time_zero.ToList().Count; i++) | |||
| { | |||
| var Out = output_time_zero.ToList()[i]; | |||
| output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); | |||
| } | |||
| var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); | |||
| Func<Tensor, Tensor>? masking_fn; | |||
| Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null; | |||
| if (mask != null) | |||
| { | |||
| if (go_backwards) | |||
| { | |||
| mask = tf.reverse(mask, axis: new[] { 0 }); | |||
| } | |||
| var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t); | |||
| mask_ta = mask_ta.unstack(mask); | |||
| masking_fn = (time) => | |||
| { | |||
| return mask_ta.read(time); | |||
| }; | |||
| compute_masked_output = (mask_t, flat_out, flat_mask) => | |||
| { | |||
| var tiled_mask_t = new Tensors(); | |||
| foreach (var o in flat_out) | |||
| { | |||
| tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); | |||
| } | |||
| Tensors res = new Tensors(); | |||
| foreach (var (m, o, fm) in zip(tiled_mask_t.ToList(), flat_out.ToList(), flat_mask.ToList())) | |||
| { | |||
| res.Add(tf.where(m, o, fm)); | |||
| } | |||
| return res; | |||
| }; | |||
| } | |||
| // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? | |||
| else if (input_length is Tensor) | |||
| { | |||
| if (go_backwards) | |||
| { | |||
| var max_len = tf.reduce_max(input_length, axis: 0); | |||
| var rev_input_length = tf.subtract(max_len - 1, input_length); | |||
| masking_fn = (time) => | |||
| { | |||
| return tf.less(rev_input_length, time); | |||
| }; | |||
| } | |||
| else | |||
| { | |||
| masking_fn = (time) => | |||
| { | |||
| return tf.greater(input_length, time); | |||
| }; | |||
| } | |||
| compute_masked_output = (mask_t, flat_out, flat_mask) => | |||
| { | |||
| var res = new List<Tensor>(); | |||
| foreach (var (o, zo) in zip(flat_out, flat_mask)) | |||
| { | |||
| res.Add(tf.where(mask_t, o, zo)); | |||
| } | |||
| return res; | |||
| }; | |||
| } | |||
| else | |||
| { | |||
| masking_fn = null; | |||
| } | |||
| Func<Tensor, Tensor> cond = (time) => (time < time_steps_t); | |||
| int parallel_iterations = 32; | |||
| if (masking_fn != null) | |||
| { | |||
| // The mask for the output at step T is based on the output of step T - 1. | |||
| // For T = 0, a zero-filled tensor is used. | |||
| var flat_zero_output = new Tensors(); | |||
| foreach (var o in Nest.Flatten(output_time_zero)) | |||
| { | |||
| flat_zero_output.Add(tf.zeros_like(o)); | |||
| } | |||
| var prev_output = flat_zero_output; | |||
| var output_ta_t = output_ta; | |||
| Tensor _step(Tensor time) | |||
| { | |||
| // Masked RNN step; passed as the loop body to tf.while_loop. | |||
| // Args: | |||
| //   time: current timestep value; used to read input_ta and the mask. | |||
| // Values carried between iterations through captured variables: | |||
| //   states: state passed to step_function for this step. | |||
| //   prev_output: flattened masked outputs produced by the previous step. | |||
| //   output_ta_t: the output TensorArrays after the writes done so far. | |||
| // Returns: | |||
| //   time + 1. output_ta and new_states are updated as side effects. | |||
| var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); | |||
| // maybe set shape | |||
| // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | |||
| var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); | |||
| var mask_t = masking_fn(time); | |||
| var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants)); | |||
| // mask output: where the mask is false, fall back to zeros or to the previous output | |||
| var flat_output = Nest.Flatten(output).ToList(); | |||
| var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); | |||
| // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type | |||
| var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); | |||
| // mask states: where the mask is false, keep the state from before this step | |||
| var flat_state = states.ToList(); | |||
| var flat_new_state = new_states_internal.ToList(); | |||
| foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | |||
| { | |||
| if (new_state is Tensor) | |||
| { | |||
| new_state.shape = state.shape; | |||
| } | |||
| } | |||
| var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); | |||
| var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | |||
| // Build a fresh list of written arrays instead of appending to the list | |||
| // that is being enumerated (same approach as the unmasked step function). | |||
| output_ta_t = zip(output_ta_t, flat_new_output).Select(item => | |||
| { | |||
| var (ta, out_) = item; | |||
| return ta.write(ta_index_to_write, out_); | |||
| }).ToList(); | |||
| // Pack the masked states; packing flat_new_state here would drop the masking. | |||
| new_states_internal = Nest.PackSequenceAs(initial_states, flat_final_state).ToTensors(); | |||
| // Carry this step's results forward so the next iteration sees them. | |||
| prev_output = flat_new_output; | |||
| states = new_states_internal; | |||
| output_ta = output_ta_t; | |||
| new_states = new_states_internal; | |||
| return time + 1; | |||
| } | |||
| var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); | |||
| } | |||
| else | |||
| { | |||
| var output_ta_t = output_ta; | |||
| new_states = states; | |||
| Tensor _step(Tensor time) | |||
| { | |||
| // Unmasked RNN step; passed as the loop body to tf.while_loop. | |||
| // Reads the inputs for `time`, runs step_function on new_states, writes the | |||
| // outputs into the output TensorArrays and returns time + 1. | |||
| // output_ta and new_states are updated as side effects. | |||
| var step_inputs = input_ta.Select(x => x.read(time)).ToList(); | |||
| // maybe set shape | |||
| // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | |||
| var current_input = Nest.PackSequenceAs(inputs, step_inputs).ToTensors(); | |||
| var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants)); | |||
| var previous_flat = new_states.Flatten().ToList(); | |||
| var updated_flat = new_states_internal.Flatten().ToList(); | |||
| // Copy each incoming state's shape onto the matching new state. | |||
| foreach (var (previous, updated) in zip(previous_flat, updated_flat)) | |||
| { | |||
| if (updated is Tensor) | |||
| { | |||
| updated.shape = previous.shape; | |||
| } | |||
| } | |||
| var flat_output = Nest.Flatten(output); | |||
| // Write at `time` when every output is kept, otherwise always at slot 0. | |||
| Tensor write_index; | |||
| if (return_all_outputs) | |||
| { | |||
| write_index = time; | |||
| } | |||
| else | |||
| { | |||
| write_index = tf.constant(0); | |||
| } | |||
| var written = new List<TensorArray>(); | |||
| foreach (var (ta, out_) in zip(output_ta_t, flat_output)) | |||
| { | |||
| written.Add(ta.write(write_index, out_)); | |||
| } | |||
| output_ta_t = written; | |||
| new_states_internal = Nest.PackSequenceAs(initial_states, updated_flat).ToTensors(); | |||
| output_ta = output_ta_t; | |||
| new_states = new_states_internal; | |||
| return time + 1; | |||
| } | |||
| var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); | |||
| } | |||
| //Tensors outputs = new Tensors(); | |||
| foreach (var o in output_ta) | |||
| { | |||
| outputs.Add(o.stack()); | |||
| } | |||
| foreach (var o in outputs) | |||
| { | |||
| last_output.Add(o[-1]); | |||
| } | |||
| outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors(); | |||
| last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors(); | |||
| } | |||
| Func<Tensor, Tensor> set_shape; | |||
| set_shape = (output_) => | |||
| { | |||
| if (output_ is Tensor) | |||
| { | |||
| var shape = output_.shape.as_int_list(); | |||
| if (return_all_outputs) | |||
| { | |||
| shape[0] = (int)time_steps; | |||
| } | |||
| else | |||
| { | |||
| shape[0] = 1; | |||
| } | |||
| shape[1] = (int)batch; | |||
| output_.shape = shape; | |||
| } | |||
| return output_; | |||
| }; | |||
| outputs = Nest.MapStructure(set_shape, outputs).ToTensors(); | |||
| if (!time_major) | |||
| { | |||
| outputs = Nest.MapStructure(swap_batch_timestep, outputs).ToTensors(); | |||
| } | |||
| return (last_output, outputs, new_states); | |||
| } | |||
| /// <summary> | |||
| /// Reverses <paramref name="input"/> along one axis by forwarding to the | |||
| /// int[] overload with a single-element array. | |||
| /// </summary> | |||
| /// <param name="input">Tensor to reverse.</param> | |||
| /// <param name="axis">Axis to reverse along.</param> | |||
| /// <returns>The result of the int[] overload.</returns> | |||
| public Tensor reverse(Tensor input, int axis) | |||
| => reverse(input, new[] { axis }); | |||
| /// <summary> | |||
| /// Reverses <paramref name="input"/> along the given axes by delegating to | |||
| /// <c>tf.reverse</c>. | |||
| /// </summary> | |||
| /// <param name="input">Tensor to reverse.</param> | |||
| /// <param name="axes">Axes passed through to <c>tf.reverse</c>.</param> | |||
| /// <returns>The tensor returned by <c>tf.reverse</c>.</returns> | |||
| public Tensor reverse(Tensor input, int[] axes) | |||
| => tf.reverse(input, axes); | |||
| /// <summary> | |||
| /// Returns <paramref name="output"/> unchanged when no ragged conversion is | |||
| /// requested; the ragged conversion itself is not implemented. | |||
| /// </summary> | |||
| /// <param name="is_ragged_output">Whether a ragged result is requested.</param> | |||
| /// <param name="output">Tensor returned as-is in the non-ragged case.</param> | |||
| /// <param name="nested_row_lengths">Not used by the current implementation.</param> | |||
| /// <param name="go_backwards">Not used by the current implementation.</param> | |||
| /// <returns><paramref name="output"/> when <paramref name="is_ragged_output"/> is false.</returns> | |||
| /// <exception cref="NotImplementedException">Thrown when <paramref name="is_ragged_output"/> is true.</exception> | |||
| public Tensor maybe_convert_to_ragged(bool is_ragged_output, Tensor output, int nested_row_lengths, bool go_backwards = false) | |||
| { | |||
| if (is_ragged_output) | |||
| { | |||
| throw new NotImplementedException("Not implemented currently, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| return output; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.Keras.Utils; | |||
| @@ -81,7 +82,7 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| else | |||
| { | |||
| _buildInputShape = new Saving.TensorShapeConfig(); | |||
| _buildInputShape = new TensorShapeConfig(); | |||
| } | |||
| if (outputs.Any(x => x.KerasHistory == null)) | |||
| @@ -325,7 +326,7 @@ namespace Tensorflow.Keras.Engine | |||
| nodes_in_decreasing_depth.append(node); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| var tensor_dict = new Dictionary<long, Queue<Tensor>>(); | |||
| // map input values | |||
| @@ -1,4 +1,5 @@ | |||
| using System.Threading; | |||
| using Tensorflow.Common.Types; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Engine | |||
| @@ -8,11 +9,11 @@ namespace Tensorflow.Keras.Engine | |||
| /// <summary> | |||
| /// Wraps `call`, applying pre- and post-processing steps. | |||
| /// </summary> | |||
| /// <param name="input"></param> | |||
| /// <param name="inputs"></param> | |||
| /// <param name="state"></param> | |||
| /// <param name="training"></param> | |||
| /// <returns></returns> | |||
| public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false) | |||
| public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null) | |||
| { | |||
| if (callContext.Value == null) | |||
| callContext.Value = new CallContext(); | |||
| @@ -30,7 +31,7 @@ namespace Tensorflow.Keras.Engine | |||
| if (!built) | |||
| MaybeBuild(inputs); | |||
| var outputs = Call(inputs, state: state, training: training); | |||
| var outputs = Call(inputs, state: states, training: training); | |||
| // memory leak | |||
| // _set_connectivity_metadata_(inputs, outputs); | |||
| @@ -32,7 +32,7 @@ using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Sessions; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -332,7 +332,7 @@ namespace Tensorflow.Keras.Engine | |||
| /// <param name="state"></param> | |||
| /// <param name="training"></param> | |||
| /// <returns></returns> | |||
| protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected virtual Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| if(ReplacedCall is not null) | |||
| { | |||
| @@ -1,8 +1,8 @@ | |||
| using System.Diagnostics; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Train; | |||
| @@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.KerasApi; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -143,7 +144,7 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| if (!_has_explicit_input_shape) | |||
| { | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| @@ -29,7 +30,7 @@ namespace Tensorflow.Keras.Layers { | |||
| base.build(input_shape); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor output = inputs; | |||
| output = tf.where(output > 0f, output, | |||
| @@ -4,7 +4,7 @@ using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers { | |||
| public class Exponential : Layer | |||
| @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers { | |||
| { | |||
| base.build(input_shape); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor output = inputs; | |||
| return tf.exp(output); | |||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Common.Types; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Layers { | |||
| @@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Layers { | |||
| public HardSigmoid ( LayerArgs args ) : base(args) { | |||
| // hard sigmoid has no arguments | |||
| } | |||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||
| protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null ) { | |||
| Tensor x = inputs; | |||
| return tf.clip_by_value( | |||
| tf.add(tf.multiply(x, 0.2f), 0.5f), 0f, 1f); | |||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Common.Types; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Layers | |||
| @@ -19,7 +20,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.args = args; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| return tf.nn.leaky_relu(inputs, alpha: alpha); | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| @@ -22,7 +23,7 @@ namespace Tensorflow.Keras.Layers { | |||
| } | |||
| base.build(input_shape); | |||
| } | |||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||
| protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { | |||
| Tensor output = inputs; | |||
| return tf.where(output > 0f, | |||
| tf.multiply(scale, output), | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using static Tensorflow.Binding; | |||
| @@ -11,8 +12,8 @@ namespace Tensorflow.Keras.Layers { | |||
| public Softmax ( SoftmaxArgs args ) : base(args) { | |||
| axis = args.axis; | |||
| } | |||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||
| Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9) | |||
| protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { | |||
| Tensor x = inputs.Length == 2 ? inputs[0] + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9) | |||
| : inputs; | |||
| Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true))); | |||
| Tensor s = tf.reduce_sum(e, axis: this.axis, keepdims: true); | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using static Tensorflow.Binding; | |||
| @@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Layers { | |||
| public Softplus ( LayerArgs args ) : base(args) { | |||
| // Softplus has no arguments | |||
| } | |||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||
| protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { | |||
| Tensor x = inputs; | |||
| return tf.log( | |||
| tf.add(tf.exp(x), 1f)); | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using static Tensorflow.Binding; | |||
| @@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Layers { | |||
| public Softsign ( LayerArgs args ) : base(args) { | |||
| // Softsign has no arguments | |||
| } | |||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||
| protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { | |||
| Tensor x = inputs; | |||
| // x / (abs(x) + 1) | |||
| return tf.div(x, tf.add(1f, tf.abs(x))); | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using static Tensorflow.Binding; | |||
| @@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Layers { | |||
| public Swish ( LayerArgs args ) : base(args) { | |||
| // Swish has no arguments | |||
| } | |||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||
| protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { | |||
| Tensor x = inputs; | |||
| // x / (1 + exp(-x)) | |||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Common.Types; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Layers | |||
| @@ -13,7 +14,7 @@ namespace Tensorflow.Keras.Layers | |||
| { | |||
| // Tanh has no arguments | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor x = inputs; | |||
| @@ -6,6 +6,7 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Common.Types; | |||
| /// <summary> | |||
| /// Base class for attention layers that can be used in sequence DNN/CNN models. | |||
| @@ -114,7 +115,7 @@ namespace Tensorflow.Keras.Layers | |||
| return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensors _inp; | |||
| Tensors _mask = null; | |||
| @@ -6,6 +6,7 @@ using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -252,7 +253,7 @@ namespace Tensorflow.Keras.Layers | |||
| return (attention_output, attention_scores); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensors _inp; | |||
| Tensor _mask = null; | |||
| @@ -349,7 +350,7 @@ namespace Tensorflow.Keras.Layers | |||
| //} | |||
| if (return_attention_scores) | |||
| return (attention_output, attention_scores); | |||
| return (attention_output, attention_scores.Single); | |||
| return attention_output; | |||
| } | |||
| } | |||
| @@ -20,6 +20,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.KerasApi; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -83,7 +84,7 @@ namespace Tensorflow.Keras.Layers | |||
| _buildInputShape = input_shape; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| var inputs_shape = array_ops.shape(inputs); | |||
| var batch_size = inputs_shape[0]; | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| @@ -103,7 +104,7 @@ namespace Tensorflow.Keras.Layers | |||
| _buildInputShape = input_shape; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = false, IOptionalArgs? optional_args = null) | |||
| { | |||
| var outputs = _convolution_op.Apply(inputs, kernel.AsTensor()); | |||
| if (use_bias) | |||
| @@ -18,6 +18,7 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| @@ -69,7 +70,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor outputs = null; | |||
| var rank = inputs.rank; | |||
| @@ -7,6 +7,7 @@ using System.Text.RegularExpressions; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.ArgsDefinition.Core; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -189,7 +190,7 @@ namespace Tensorflow.Keras.Layers | |||
| // return new dict(base_config.items().ToList() + config.items().ToList()); | |||
| //} | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| var ret = tf.linalg.einsum(this.equation, (inputs, this.kernel.AsTensor())); | |||
| if (this.bias != null) | |||
| @@ -15,6 +15,7 @@ | |||
| ******************************************************************************/ | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| @@ -66,7 +67,7 @@ namespace Tensorflow.Keras.Layers | |||
| _buildInputShape = input_shape; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| var dtype = inputs.dtype; | |||
| if (dtype != tf.int32 && dtype != tf.int64) | |||
| @@ -5,6 +5,7 @@ using static Tensorflow.Binding; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Layers | |||
| _buildInputShape = input_shape; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| return _merge_function(inputs); | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| @@ -146,7 +147,7 @@ namespace Tensorflow.Keras.Layers | |||
| return false; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor outputs = null; | |||
| var training_tensor = training == null | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| @@ -101,7 +102,7 @@ namespace Tensorflow.Keras.Layers | |||
| return input_shape; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor outputs = null; | |||
| var inputs_dtype = inputs.dtype.as_base_dtype(); | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Saving; | |||
| @@ -157,7 +158,7 @@ namespace Tensorflow.Keras.Layers | |||
| base.adapt(data, batch_size: batch_size, steps: steps); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| if (_args.Invert) | |||
| { | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -12,7 +13,7 @@ namespace Tensorflow.Keras.Layers | |||
| { | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| if (data_format == "channels_last") | |||
| return math_ops.reduce_mean(inputs, 1, false); | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -12,7 +13,7 @@ namespace Tensorflow.Keras.Layers | |||
| { | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| if (data_format == "channels_last") | |||
| return math_ops.reduce_mean(inputs, (1, 2), false); | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -12,7 +13,7 @@ namespace Tensorflow.Keras.Layers | |||
| { | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| if (data_format == "channels_last") | |||
| return math_ops.reduce_max(inputs, 1, false); | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -12,7 +13,7 @@ namespace Tensorflow.Keras.Layers | |||
| { | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| if (data_format == "channels_last") | |||
| return math_ops.reduce_max(inputs, (1, 2), false); | |||
| @@ -18,6 +18,7 @@ using System.Linq; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Common.Types; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Layers | |||
| @@ -36,7 +37,7 @@ namespace Tensorflow.Keras.Layers | |||
| input_spec = new InputSpec(ndim: 3); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| int pad_axis = args.DataFormat == "channels_first" ? 2 : 3; | |||
| inputs = tf.expand_dims(inputs, pad_axis); | |||
| @@ -17,6 +17,7 @@ | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -36,7 +37,7 @@ namespace Tensorflow.Keras.Layers | |||
| input_spec = new InputSpec(ndim: 4); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| int[] pool_shape; | |||
| int[] strides; | |||
| @@ -1,6 +1,6 @@ | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| /// <summary> | |||
| @@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.args = args; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| var depth = args.NumTokens; | |||
| var max_value = tf.reduce_max(inputs); | |||
| @@ -1,5 +1,6 @@ | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -17,7 +18,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.args = args; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| scale = constant_op.constant(args.Scale, args.DType); | |||
| offset = constant_op.constant(args.Offset, args.DType); | |||
| @@ -4,6 +4,7 @@ using System; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -19,7 +20,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.args = args; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
/// <summary>
/// Resizes <paramref name="inputs"/> to the height and width configured in the layer args,
/// using the configured interpolation method.
/// </summary>
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
    // Target size is (Height, Width), taken from the layer arguments.
    var target_size = new[] { args.Height, args.Width };
    return image_ops_impl.resize_images_v2(inputs, target_size, method: args.Interpolation);
}
| @@ -1,4 +1,5 @@ | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.Binding; | |||
| @@ -15,7 +16,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.args = args; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| if (training == null) | |||
| training = false; | |||
| @@ -1,6 +1,8 @@ | |||
| using Tensorflow.Keras.ArgsDefinition.Reshaping; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers.Reshaping | |||
| { | |||
| @@ -27,7 +29,7 @@ namespace Tensorflow.Keras.Layers.Reshaping | |||
| _buildInputShape = input_shape; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor output = inputs; | |||
| if (output.rank != 3) | |||
| @@ -1,6 +1,7 @@ | |||
| using Tensorflow.Keras.ArgsDefinition.Reshaping; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers.Reshaping | |||
| { | |||
| @@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Layers.Reshaping | |||
| built = true; | |||
| _buildInputShape = input_shape; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor output = inputs; | |||
| if (output.rank != 4) | |||
| @@ -1,6 +1,7 @@ | |||
| using Tensorflow.Keras.ArgsDefinition.Reshaping; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers.Reshaping | |||
| { | |||
| @@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Layers.Reshaping | |||
| _buildInputShape = input_shape; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor output = inputs; | |||
| if (output.rank != 5) | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| @@ -23,7 +24,7 @@ namespace Tensorflow.Keras.Layers | |||
| _channels_first = args.DataFormat == "channels_first"; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| if (_channels_first) | |||
| { | |||
| @@ -6,6 +6,7 @@ using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers { | |||
| public class Permute : Layer | |||
| @@ -28,7 +29,7 @@ namespace Tensorflow.Keras.Layers { | |||
| built = true; | |||
| _buildInputShape = input_shape; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| Tensor outputs = inputs; | |||
| return tf.transpose(outputs, new Axis(permute)); | |||
| @@ -4,6 +4,7 @@ using static Tensorflow.Binding; | |||
| using System.Collections.Generic; | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -19,7 +20,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.args = args; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| var shapes = new List<Tensor>(); | |||
| shapes.Add(array_ops.shape(inputs)[0]); | |||
| @@ -6,6 +6,7 @@ using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -24,7 +25,7 @@ namespace Tensorflow.Keras.Layers | |||
| inputSpec = new InputSpec(ndim: 4); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| return keras.backend.resize_images(inputs, | |||
| size[0], size[1], | |||
| @@ -2,6 +2,7 @@ | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Common.Types; | |||
| using static Tensorflow.KerasApi; | |||
| namespace Tensorflow.Keras.Layers | |||
| @@ -26,7 +27,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.input_spec = new InputSpec(ndim: 4); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| return keras.backend.spatial_2d_padding(inputs, | |||
| padding: padding, | |||
| @@ -0,0 +1,85 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
/// <summary>
/// Base class for RNN cells that need dropout masks for their inputs and for
/// their recurrent state. Masks are produced by applying a <c>Dropout</c> layer
/// to a tensor of ones shaped like the given input.
/// </summary>
public abstract class DropoutRNNCellMixin : RnnCellBase
{
    /// <summary>Dropout rate for the cell inputs; a value of 0 disables the input mask.</summary>
    public float dropout;

    /// <summary>Dropout rate for the recurrent state; a value of 0 disables the recurrent mask.</summary>
    public float recurrent_dropout;

    // TODO(Rinne): deal with cache.
    public DropoutRNNCellMixin(LayerArgs args) : base(args)
    {
    }

    /// <summary>
    /// Gets the input dropout mask for the cell.
    /// </summary>
    /// <param name="input">Tensor whose shape the mask should match.</param>
    /// <param name="training">Forwarded to the dropout layer.</param>
    /// <param name="count">Number of masks to generate.</param>
    /// <returns>The mask(s), or <c>null</c> when <see cref="dropout"/> is 0.</returns>
    public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
    {
        if (dropout == 0f)
        {
            return null;
        }
        return _generate_dropout_mask(tf.ones_like(input), dropout, training, count);
    }

    /// <summary>
    /// Gets the recurrent dropout mask for the cell.
    /// </summary>
    /// <param name="input">Tensor whose shape the mask should match.</param>
    /// <param name="training">Forwarded to the dropout layer.</param>
    /// <param name="count">Number of masks to generate.</param>
    /// <returns>The mask(s), or <c>null</c> when <see cref="recurrent_dropout"/> is 0.</returns>
    public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
    {
        // The guard must look at the recurrent rate: checking `dropout` here would
        // suppress the recurrent mask whenever input dropout is disabled, and would
        // build a pointless rate-0 mask when only input dropout is enabled.
        if (recurrent_dropout == 0f)
        {
            return null;
        }
        return _generate_dropout_mask(tf.ones_like(input), recurrent_dropout, training, count);
    }

    /// <summary>
    /// Builds the input dropout mask unconditionally (no check for a zero rate).
    /// </summary>
    public Tensors _create_dropout_mask(Tensors input, bool training, int count = 1)
    {
        return _generate_dropout_mask(tf.ones_like(input), dropout, training, count);
    }

    /// <summary>
    /// Builds the recurrent dropout mask unconditionally (no check for a zero rate).
    /// </summary>
    public Tensors _create_recurrent_dropout_mask(Tensors input, bool training, int count = 1)
    {
        return _generate_dropout_mask(tf.ones_like(input), recurrent_dropout, training, count);
    }

    /// <summary>
    /// Applies a freshly constructed <c>Dropout</c> layer with the given rate to <paramref name="ones"/>.
    /// </summary>
    /// <param name="ones">Tensor the dropout layer is applied to.</param>
    /// <param name="rate">Dropout rate passed to the layer.</param>
    /// <param name="training">Forwarded to the dropout layer.</param>
    /// <param name="count">When greater than 1, that many masks are generated and collected.</param>
    /// <returns>A single mask, or a collection of <paramref name="count"/> masks.</returns>
    public Tensors _generate_dropout_mask(Tensor ones, float rate, bool training, int count = 1)
    {
        // Each invocation creates its own Dropout layer and produces one mask.
        Tensors dropped_inputs()
        {
            DropoutArgs args = new DropoutArgs();
            args.Rate = rate;
            var DropoutLayer = new Dropout(args);
            var mask = DropoutLayer.Apply(ones, training: training);
            return mask;
        }

        if (count > 1)
        {
            Tensors results = new Tensors();
            for (int i = 0; i < count; i++)
            {
                results.Add(dropped_inputs());
            }
            return results;
        }

        return dropped_inputs();
    }
}
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using System.Linq; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| @@ -26,9 +27,9 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| .ToArray(); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
/// <summary>
/// Delegates to the base implementation, handing <paramref name="state"/> over as the initial state.
/// </summary>
/// <param name="inputs">Layer inputs.</param>
/// <param name="state">State forwarded to the base call as <c>initial_state</c>.</param>
/// <param name="training">Forwarded unchanged.</param>
/// <param name="optional_args">Not forwarded by this override.</param>
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
    // The earlier `state: state` call was left in front of this one, which made this
    // statement unreachable. The parameter is named `initial_state` on RNN.Call in this
    // change set (presumably the base class here - confirm).
    // NOTE(review): optional_args is dropped here; confirm whether it should be passed on.
    return base.Call(inputs, initial_state: state, training: training);
}
| } | |||
| } | |||
| @@ -1,53 +1,468 @@ | |||
| using System; | |||
| using OneOf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Reflection; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Util; | |||
| using Tensorflow.Common.Extensions; | |||
| using System.Linq.Expressions; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Common.Types; | |||
| // from tensorflow.python.distribute import distribution_strategy_context as ds_context; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| public class RNN : Layer | |||
| /// <summary> | |||
| /// Base class for recurrent layers. | |||
| /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) | |||
| /// for details about the usage of RNN API. | |||
| /// </summary> | |||
| public class RNN : RnnBase | |||
| { | |||
| private RNNArgs args; | |||
| private object input_spec = null; // or NoneValue?? | |||
| private object state_spec = null; | |||
| private object _states = null; | |||
| private object constants_spec = null; | |||
| private int _num_constants = 0; | |||
| protected IVariableV1 kernel; | |||
| protected IVariableV1 bias; | |||
| protected ILayer cell; | |||
| private RNNArgs _args; | |||
| private object _input_spec = null; // or NoneValue?? | |||
| private object _state_spec = null; | |||
| private Tensors _states = null; | |||
| private object _constants_spec = null; | |||
| private int _num_constants; | |||
| protected IVariableV1 _kernel; | |||
| protected IVariableV1 _bias; | |||
| protected IRnnCell _cell; | |||
/// <summary>
/// Creates the RNN layer around the cell supplied in <paramref name="args"/>.
/// </summary>
/// <param name="args">Layer arguments; <c>args.Cell</c> becomes the wrapped cell.</param>
public RNN(RNNArgs args) : base(PreConstruct(args))
{
    // Both the `args` field and the `_args` field are assigned, exactly as before.
    this.args = args;
    this._args = args;
    this.SupportsMasking = true;

    // The input shape is unknown yet, it could have nested tensor inputs, and
    // the input spec will be the list of specs for nested inputs, the structure
    // of the input_spec will be the same as the input.

    // if is StackedRnncell
    this._cell = args.Cell;

    // get input_shape
    this._args = PreConstruct(args);

    this._num_constants = 0;
}
/// <summary>
/// Current states of the layer.
/// States is a tuple consisting of the cell state sizes, like (cell1.state_size, cell2.state_size, ...).
/// A state_size can be a single integer, a list/tuple of integers, a TensorShape,
/// or a list/tuple of TensorShape.
/// </summary>
/// <remarks>
/// When no states have been set, the getter builds a structure shaped like the
/// cell's <c>StateSize</c> with every entry mapped to <c>null</c>, stores it and returns it.
/// </remarks>
public Tensors States
{
    get
    {
        if (_states != null)
        {
            return _states;
        }

        // CHECK(Rinne): check if this is correct.
        var placeholders = _cell.StateSize.MapStructure<Tensor?>(x => null);
        _states = placeholders.AsNest().ToTensors();
        return _states;
    }
    set => _states = value;
}
/// <summary>
/// Computes the output shape of the layer for the given input shape.
/// </summary>
/// <param name="input_shape">Input shape; the first two dimensions are read as batch and time step
/// (swapped when the layer is time-major).</param>
/// <returns>
/// The output shape, or a list of the output shape followed by the state shape when
/// <c>ReturnState</c> is set.
/// </returns>
private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
{
    var batch = _args.TimeMajor ? input_shape[1] : input_shape[0];
    var time_step = _args.TimeMajor ? input_shape[0] : input_shape[1];

    // state_size is an array of ints or a positive integer
    var state_size = _cell.StateSize.ToSingleShape();

    // TODO(wanglongzhi2001): what type should flat_output_size be, Shape or Tensor?
    Func<Shape, Shape> _get_output_shape = (flat_output_size) =>
    {
        var feature_dims = flat_output_size.as_int_list();
        if (!_args.ReturnSequences)
        {
            return new Shape(new int[] { (int)batch }.concat(feature_dims));
        }
        if (_args.TimeMajor)
        {
            return new Shape(new int[] { (int)time_step, (int)batch }.concat(feature_dims));
        }
        return new Shape(new int[] { (int)batch, (int)time_step }.concat(feature_dims));
    };

    // NOTE(review): this looks up a property literally named "output_size", while the
    // branch below reads `_cell.OutputSize` - confirm the intended property name.
    PropertyInfo output_size_info = _cell.GetType().GetProperty("output_size");
    Shape output_shape;
    if (output_size_info == null)
    {
        output_shape = _get_output_shape(state_size);
    }
    else
    {
        output_shape = nest.map_structure(_get_output_shape, _cell.OutputSize.ToSingleShape());
        // TODO(wanglongzhi2001): should output_shape simply be a tuple, or a Shape?
        output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
    }

    if (!_args.ReturnState)
    {
        return output_shape;
    }

    // The state shape is the flat state size with the batch dimension prepended.
    Func<Shape, Shape> _get_state_shape = (flat_state) =>
        new Shape(new int[] { (int)batch }.concat(flat_state.as_int_list()));
    return new List<Shape> { output_shape, _get_state_shape(state_size) };

    // Not ported yet:
    //if(stateful)
    //{
    //    if (ds_context.has_strategy()) // ds_context????
    //    {
    //        throw new Exception("RNNs with stateful=True not yet supported with tf.distribute.Strategy");
    //    }
    //}
}
/// <summary>
/// Computes the output mask of the layer from the incoming mask.
/// </summary>
/// <param name="inputs">Not read by this method.</param>
/// <param name="mask">Incoming mask structure; only its first flattened element is used.</param>
/// <returns>
/// The mask when <c>ReturnSequences</c> is set, otherwise <c>null</c>; when <c>ReturnState</c>
/// is set, that value followed by one <c>null</c> entry per state.
/// </returns>
private Tensors compute_mask(Tensors inputs, Tensors mask)
{
    // Time step masks must be the same for each input.
    // This is because the mask for an RNN is of size [batch, time_steps, 1],
    // and specifies which time steps should be skipped, and a time step
    // must be skipped for all inputs.
    mask = nest.flatten(mask)[0];
    var output_mask = _args.ReturnSequences ? mask : null;

    if (!_args.ReturnState)
    {
        return output_mask;
    }

    // States never carry a mask: add one null placeholder per state.
    var state_mask = new List<Tensor>();
    var state_count = len(States);
    for (int added = 0; added < state_count; added++)
    {
        state_mask.Add(null);
    }
    return new List<Tensor> { output_mask }.concat(state_mask);
}
/// <summary>
/// Builds the wrapped cell with <paramref name="input_shape"/> if it has not been built yet.
/// </summary>
/// <param name="input_shape">Shape wrapper forwarded to the cell's <c>build</c>.</param>
public override void build(KerasShapesWrapper input_shape)
{
    // A leftover `if (!cell.Built)` used to sit here, in front of the local function
    // below; the build check now happens once, on `_cell`, at the end of the method.

    // The three local helpers below are not called yet in this method.

    // Input spec for a shape: the time-step dimension is always left open (-1), and
    // the batch dimension too unless the layer is stateful.
    object get_input_spec(Shape shape)
    {
        var input_spec_shape = shape.as_int_list();

        var (batch_index, time_step_index) = _args.TimeMajor ? (1, 0) : (0, 1);
        if (!_args.Stateful)
        {
            input_spec_shape[batch_index] = -1;
        }
        input_spec_shape[time_step_index] = -1;
        return new InputSpec(shape: input_spec_shape);
    }

    // Shape of a single step's input: drops the time dimension.
    Shape get_step_input_shape(Shape shape)
    {
        // return shape[1:] if self.time_major else (shape[0],) + shape[2:]
        if (_args.TimeMajor)
        {
            return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray();
        }
        else
        {
            return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray());
        }
    }

    // State spec for a shape: the state shape with an open (-1) batch dimension prepended.
    object get_state_spec(Shape shape)
    {
        var state_spec_shape = shape.as_int_list();
        // prepend the batch dim
        state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
        return new InputSpec(shape: state_spec_shape);
    }

    // Check whether the input shape contains any nested shapes. It could be
    // (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
    // numpy inputs.

    if (!_cell.Built)
    {
        _cell.build(input_shape);
    }
}
/// <summary>
/// Runs the wrapped cell over the time dimension of <paramref name="inputs"/> via <c>keras.backend.rnn</c>.
/// </summary>
/// <param name="inputs">Input tensor(s). Extra entries are split off by <c>_process_inputs</c>.</param>
/// <param name="initial_state">List of initial state tensors to be passed to the first call of the cell.</param>
/// <param name="training">Not read by this method.</param>
/// <param name="optional_args">
/// Must be a <c>RnnOptionalArgs</c> when supplied. Its <c>Mask</c> is a binary tensor of shape
/// [batch_size, timesteps] indicating whether a given timestep should be masked; its
/// <c>Constants</c> is a list of constant tensors to be passed to the cell at each timestep.
/// </param>
/// <returns>
/// All outputs when <c>ReturnSequences</c> is set, otherwise the last output; the final states
/// are appended when <c>ReturnState</c> is set.
/// </returns>
/// <exception cref="ArgumentException"><paramref name="optional_args"/> is not a <c>RnnOptionalArgs</c>.</exception>
/// <exception cref="ValueError">Unrolling was requested with an undefined time dimension, or
/// constants were given to a cell that does not support optional args.</exception>
/// <exception cref="NotImplementedException">The layer is stateful.</exception>
protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
    RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
    if (optional_args is not null && rnn_optional_args is null)
    {
        throw new ArgumentException("The optional args shhould be of type `RnnOptionalArgs`");
    }
    Tensors? constants = rnn_optional_args?.Constants;
    Tensors? mask = rnn_optional_args?.Mask;
    //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
    // Ragged tensors are not accepted for now.
    int row_length = 0; // TODO(Rinne): support this param.
    bool is_ragged_input = false;

    _validate_args_if_ragged(is_ragged_input, mask);

    (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);

    _maybe_reset_cell_dropout_mask(_cell);
    if (_cell is StackedRNNCells stack_cell)
    {
        foreach (var cell in stack_cell.Cells)
        {
            _maybe_reset_cell_dropout_mask(cell);
        }
    }

    if (mask != null)
    {
        // Time step masks must be the same for each input.
        mask = mask.Flatten().First();
    }

    Shape input_shape;
    if (!inputs.IsSingle())
    {
        // In the case of nested input, use the first element for shape check
        // input_shape = nest.flatten(inputs)[0].shape;
        // TODO(Wanglongzhi2001)
        input_shape = inputs.Flatten().First().shape;
    }
    else
    {
        input_shape = inputs.shape;
    }

    var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];

    // Unrolling needs a known number of time steps. The previous check was inverted
    // (`timesteps != null`), so it rejected exactly the cases that can be unrolled.
    // A negative dimension is treated as undefined as well (unknown dimensions are
    // presumably encoded as -1, as in the input specs built by this class - confirm).
    if (_args.Unroll && (timesteps == null || (int)timesteps < 0))
    {
        throw new ValueError(
            "Cannot unroll a RNN if the " +
            "time dimension is undefined. \n" +
            "- If using a Sequential model, " +
            "specify the time dimension by passing " +
            "an `input_shape` or `batch_input_shape` " +
            "argument to your first layer. If your " +
            "first layer is an Embedding, you can " +
            "also use the `input_length` argument.\n" +
            "- If using the functional API, specify " +
            "the time dimension by passing a `shape` " +
            "or `batch_shape` argument to your Input layer."
        );
    }

    // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
    Func<Tensors, Tensors, (Tensors, Tensors)> step;
    bool is_tf_rnn_cell = _cell.IsTFRnnCell;
    if (constants is not null)
    {
        if (!_cell.SupportOptionalArgs)
        {
            throw new ValueError(
                $"RNN cell {_cell} does not support constants." +
                $"Received: constants={constants}");
        }

        step = (inputs, states) =>
        {
            // The trailing `_num_constants` entries of `states` are the constants;
            // the rest are the real states handed to the cell.
            constants = new Tensors(states.TakeLast(_num_constants));
            states = new Tensors(states.SkipLast(_num_constants));
            states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
            var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
            // TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors?
            return (output, new_states.Single);
        };
    }
    else
    {
        step = (inputs, states) =>
        {
            states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
            var (output, new_states) = _cell.Apply(inputs, states);
            return (output, new_states.Single);
        };
    }

    // NOTE(review): `row_length` is a non-nullable int, so `row_length != null` always
    // holds and `input_length` is always built from `row_length` (currently 0).
    var (last_output, outputs, states) = keras.backend.rnn(step,
        inputs,
        initial_state,
        constants: constants,
        go_backwards: _args.GoBackwards,
        mask: mask,
        unroll: _args.Unroll,
        input_length: row_length != null ? new Tensor(row_length) : new Tensor(timesteps),
        time_major: _args.TimeMajor,
        zero_output_for_mask: _args.ZeroOutputForMask,
        return_all_outputs: _args.ReturnSequences);

    if (_args.Stateful)
    {
        throw new NotImplementedException("this argument havn't been developed.");
    }

    Tensors output;
    if (_args.ReturnSequences)
    {
        // TODO(Rinne): add go_backwards parameter and revise the `row_length` param
        output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false);
    }
    else
    {
        output = last_output;
    }

    if (_args.ReturnState)
    {
        // The final states follow the output in the returned collection.
        foreach (var state in states)
        {
            output.Add(state);
        }
    }
    return output;
}
/// <summary>
/// Applies the layer. Only the case without initial states and constants is implemented.
/// </summary>
/// <param name="inputs">Layer inputs.</param>
/// <param name="initial_states">Optional initial states.</param>
/// <param name="training">Not forwarded by this override.</param>
/// <param name="optional_args">Must be a <c>RnnOptionalArgs</c> when supplied; its <c>Constants</c> are read.</param>
/// <returns>The result of the base <c>Apply</c> on the standardized inputs.</returns>
/// <exception cref="ArgumentException"><paramref name="optional_args"/> is not a <c>RnnOptionalArgs</c>.</exception>
/// <exception cref="NotImplementedException">Initial states or constants remain after standardization.</exception>
public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
{
    RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
    if (optional_args is not null && rnn_optional_args is null)
    {
        // A leftover `cell.build(input_shape);` was removed from this branch:
        // no `input_shape` exists in this method.
        throw new ArgumentException("The type of optional args should be `RnnOptionalArgs`.");
    }
    Tensors? constants = rnn_optional_args?.Constants;
    (inputs, initial_states, constants) = RnnUtils.standardize_args(inputs, initial_states, constants, _num_constants);

    if (initial_states is null && constants is null)
    {
        // NOTE(review): `training` and `optional_args` are not passed on here - confirm intended.
        return base.Apply(inputs);
    }

    // TODO(Rinne): implement it.
    throw new NotImplementedException();
}
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
/// <summary>
/// Splits a multi-tensor <paramref name="inputs"/> into the actual input, the initial state and
/// the constants, then resolves the initial state to use.
/// </summary>
/// <param name="inputs">Input tensor(s); with more than one entry, the first is the input and the
/// rest are initial states followed by <c>_num_constants</c> constants.</param>
/// <param name="initial_state">Explicit initial state; replaced when extra inputs are present.</param>
/// <param name="constants">Explicit constants; replaced when extra inputs carry constants.</param>
/// <returns>The processed (inputs, initial_state, constants) triple.</returns>
/// <exception cref="ValueError">The number of initial states differs from the number of layer states.</exception>
private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
{
    // A leftover `return base.Call(inputs, state, training);` used to sit here; it
    // referenced a `state` that does not exist in this method and made the rest unreachable.
    if (inputs.Length > 1)
    {
        // The branches were swapped before: constants can only be peeled off the tail
        // when the layer actually expects some. With zero expected constants the old
        // code produced an empty (non-null) constants collection.
        if (_num_constants != 0)
        {
            initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants));
            constants = new Tensors(inputs.TakeLast(_num_constants));
        }
        else
        {
            initial_state = new Tensors(inputs.Skip(1));
        }
        if (len(initial_state) == 0)
            initial_state = null;
        inputs = inputs[0];
    }

    if (_args.Stateful)
    {
        if (initial_state != null)
        {
            // Collect into a growable list: the previous zero-length array could
            // never receive the per-state counts.
            var non_zero_counts = new List<Tensor>();
            foreach (var s in nest.flatten(States))
            {
                non_zero_counts.Add(tf.math.count_nonzero((Tensor)s));
            }
            var non_zero_count = tf.add_n(non_zero_counts.ToArray());
            //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
            // Stored states take precedence over the supplied ones once any of them is non-zero.
            if ((int)non_zero_count.numpy() > 0)
            {
                initial_state = States;
            }
        }
        else
        {
            initial_state = States;
        }
    }
    else if (initial_state is null)
    {
        initial_state = get_initial_state(inputs);
    }

    if (initial_state.Length != States.Length)
    {
        throw new ValueError(
            $"Layer {this} expects {States.Length} state(s), " +
            $"but it received {initial_state.Length} " +
            $"initial state(s). Input received: {inputs}");
    }

    return (inputs, initial_state, constants);
}
/// <summary>
/// Rejects argument combinations that are not supported for ragged inputs.
/// Does nothing when the input is not ragged.
/// </summary>
/// <param name="is_ragged_input">Whether the layer input is ragged.</param>
/// <param name="mask">Mask passed to the layer; must be null for ragged inputs.</param>
/// <exception cref="ValueError">The input is ragged and unrolling is enabled or a mask was passed.</exception>
private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
{
    if (is_ragged_input)
    {
        if (_args.Unroll)
        {
            throw new ValueError("The input received contains RaggedTensors and does " +
                "not support unrolling. Disable unrolling by passing " +
                "`unroll=False` in the RNN Layer constructor.");
        }

        if (mask != null)
        {
            throw new ValueError($"The mask that was passed in was {mask}, which " +
                "cannot be applied to RaggedTensor inputs. Please " +
                "make sure that there is no mask injected by upstream " +
                "layers.");
        }
    }
}
/// <summary>
/// Placeholder: currently does nothing with <paramref name="cell"/>.
/// </summary>
/// <param name="cell">The cell whose dropout masks would be reset.</param>
private void _maybe_reset_cell_dropout_mask(ILayer cell)
{
    // Not enabled yet; the intended logic is:
    //
    // if (cell is DropoutRNNCellMixin)
    // {
    //     cell.reset_dropout_mask();
    //     cell.reset_recurrent_dropout_mask();
    // }
}
| private static RNNArgs PreConstruct(RNNArgs args) | |||
| @@ -77,60 +492,72 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| return args; | |||
| } | |||
/// <summary>
/// Creates an <see cref="RNN"/> layer around a single cell.
/// </summary>
/// <param name="cell">The cell to wrap.</param>
/// <param name="return_sequences">Stored as <c>ReturnSequences</c>.</param>
/// <param name="return_state">Stored as <c>ReturnState</c>.</param>
/// <param name="go_backwards">Stored as <c>GoBackwards</c>.</param>
/// <param name="stateful">Stored as <c>Stateful</c>.</param>
/// <param name="unroll">Stored as <c>Unroll</c>.</param>
/// <param name="time_major">Stored as <c>TimeMajor</c>.</param>
public RNN New(LayerRnnCell cell,
    bool return_sequences = false,
    bool return_state = false,
    bool go_backwards = false,
    bool stateful = false,
    bool unroll = false,
    bool time_major = false)
{
    var rnn_args = new RNNArgs
    {
        Cell = cell,
        ReturnSequences = return_sequences,
        ReturnState = return_state,
        GoBackwards = go_backwards,
        Stateful = stateful,
        Unroll = unroll,
        TimeMajor = time_major
    };
    return new RNN(rnn_args);
}
/// <summary>
/// Creates an <see cref="RNN"/> layer around a list of cells, which are wrapped in a
/// <c>StackedRNNCells</c>.
/// </summary>
/// <param name="cell">The cells to stack.</param>
/// <param name="return_sequences">Stored as <c>ReturnSequences</c>.</param>
/// <param name="return_state">Stored as <c>ReturnState</c>.</param>
/// <param name="go_backwards">Stored as <c>GoBackwards</c>.</param>
/// <param name="stateful">Stored as <c>Stateful</c>.</param>
/// <param name="unroll">Stored as <c>Unroll</c>.</param>
/// <param name="time_major">Stored as <c>TimeMajor</c>.</param>
public RNN New(IList<RnnCell> cell,
    bool return_sequences = false,
    bool return_state = false,
    bool go_backwards = false,
    bool stateful = false,
    bool unroll = false,
    bool time_major = false)
{
    var rnn_args = new RNNArgs
    {
        Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }),
        ReturnSequences = return_sequences,
        ReturnState = return_state,
        GoBackwards = go_backwards,
        Stateful = stateful,
        Unroll = unroll,
        TimeMajor = time_major
    };
    return new RNN(rnn_args);
}
| protected Tensor get_initial_state(Tensor inputs) | |||
/// <summary>
/// Not implemented.
/// </summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null)
{
    // A leftover `return _generate_zero_filled_state_for_cell(null, null);` was removed:
    // that helper has no body in this file any more.
    throw new NotImplementedException();
}
| Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size) | |||
| // It seems the cell cannot be passed as an interface type. | |||
| //public RNN New(IRnnArgCell cell, | |||
| // bool return_sequences = false, | |||
| // bool return_state = false, | |||
| // bool go_backwards = false, | |||
| // bool stateful = false, | |||
| // bool unroll = false, | |||
| // bool time_major = false) | |||
| // => new RNN(new RNNArgs | |||
| // { | |||
| // Cell = cell, | |||
| // ReturnSequences = return_sequences, | |||
| // ReturnState = return_state, | |||
| // GoBackwards = go_backwards, | |||
| // Stateful = stateful, | |||
| // Unroll = unroll, | |||
| // TimeMajor = time_major | |||
| // }); | |||
| //public RNN New(List<IRnnArgCell> cell, | |||
| // bool return_sequences = false, | |||
| // bool return_state = false, | |||
| // bool go_backwards = false, | |||
| // bool stateful = false, | |||
| // bool unroll = false, | |||
| // bool time_major = false) | |||
| // => new RNN(new RNNArgs | |||
| // { | |||
| // Cell = cell, | |||
| // ReturnSequences = return_sequences, | |||
| // ReturnState = return_state, | |||
| // GoBackwards = go_backwards, | |||
| // Stateful = stateful, | |||
| // Unroll = unroll, | |||
| // TimeMajor = time_major | |||
| // }); | |||
| /// <summary> | |||
| /// Builds the initial state for the wrapped cell from the first tensor in <paramref name="inputs"/>. | |||
| /// The batch size is read from dimension 1 of that tensor's shape when <c>_args.TimeMajor</c> is set, | |||
| /// otherwise from dimension 0; the dtype is taken from the same tensor. | |||
| /// </summary> | |||
| /// <param name="inputs">Input tensors; only <c>inputs[0]</c> is inspected.</param> | |||
| /// <returns> | |||
| /// The result of <c>RnnCellBase.GetInitialState</c> when the cell derives from <c>RnnCellBase</c>, | |||
| /// otherwise the result of <c>RnnUtils.generate_zero_filled_state</c> for the cell's <c>StateSize</c>. | |||
| /// </returns> | |||
| protected Tensors get_initial_state(Tensors inputs) | |||
| { | |||
| // NOTE(review): this throw looks like the pre-change body shown next to the new one in this diff view; | |||
| // if it is live, nothing below it executes -- confirm against the actual file. | |||
| throw new NotImplementedException(""); | |||
| var input = inputs[0]; | |||
| var input_shape = input.shape; | |||
| // Time-major inputs carry the batch on axis 1, otherwise on axis 0. | |||
| var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; | |||
| var dtype = input.dtype; | |||
| Tensors init_state; | |||
| if (_cell is RnnCellBase rnn_base_cell) | |||
| { | |||
| // Null inputs make the cell rely on the explicit batch_size and dtype passed here. | |||
| init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype); | |||
| } | |||
| else | |||
| { | |||
| init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype); | |||
| } | |||
| return init_state; | |||
| } | |||
| // Check whether the state_size contains multiple states. | |||
| public static bool _is_multiple_state(object state_size) | |||
| public static bool is_multiple_state(GeneralizedTensorShape state_size) | |||
| { | |||
| var myIndexerProperty = state_size.GetType().GetProperty("Item"); | |||
| return myIndexerProperty != null | |||
| && myIndexerProperty.GetIndexParameters().Length == 1 | |||
| && !(state_size.GetType() == typeof(Shape)); | |||
| return state_size.Shapes.Length > 1; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,13 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| /// <summary> | |||
| /// Abstract base type for recurrent layers. It adds nothing to <see cref="Layer"/> beyond | |||
| /// forwarding the constructor arguments. | |||
| /// </summary> | |||
| public abstract class RnnBase: Layer | |||
| { | |||
| // Protected rather than public: an abstract type can only be constructed through a derived class (CA1012). | |||
| /// <summary>Forwards <paramref name="args"/> to the <see cref="Layer"/> constructor.</summary> | |||
| /// <param name="args">Layer arguments passed unchanged to the base class.</param> | |||
| protected RnnBase(LayerArgs args): base(args) { } | |||
| } | |||
| } | |||
| @@ -0,0 +1,24 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| /// <summary> | |||
| /// Abstract base for RNN cells implemented as layers. It declares the size and capability | |||
| /// properties as abstract and supplies a default initial state. | |||
| /// </summary> | |||
| public abstract class RnnCellBase: Layer, IRnnCell | |||
| { | |||
| // Protected rather than public: an abstract type can only be constructed through a derived class (CA1012). | |||
| /// <summary>Forwards <paramref name="args"/> to the <see cref="Layer"/> constructor.</summary> | |||
| /// <param name="args">Layer arguments passed unchanged to the base class.</param> | |||
| protected RnnCellBase(LayerArgs args) : base(args) { } | |||
| /// <summary>Shape description of the cell's state; used to size the default initial state.</summary> | |||
| public abstract GeneralizedTensorShape StateSize { get; } | |||
| /// <summary>Shape description of the cell's output.</summary> | |||
| public abstract GeneralizedTensorShape OutputSize { get; } | |||
| /// <summary>Flag supplied by the concrete cell; its consumers are not visible in this file.</summary> | |||
| public abstract bool IsTFRnnCell { get; } | |||
| /// <summary>Flag supplied by the concrete cell; its consumers are not visible in this file.</summary> | |||
| public abstract bool SupportOptionalArgs { get; } | |||
| /// <summary> | |||
| /// Returns the default initial state by delegating to | |||
| /// <c>RnnUtils.generate_zero_filled_state_for_cell</c>. | |||
| /// </summary> | |||
| /// <param name="inputs">Optional inputs, forwarded unchanged; may be null.</param> | |||
| /// <param name="batch_size">Batch size, forwarded unchanged.</param> | |||
| /// <param name="dtype">Element type, forwarded unchanged.</param> | |||
| /// <returns>The state produced by the helper for this cell.</returns> | |||
| public virtual Tensors GetInitialState(Tensors inputs, long batch_size, TF_DataType dtype) | |||
| { | |||
| return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); | |||
| } | |||
| } | |||
| } | |||
| @@ -10,18 +10,36 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| public class SimpleRNN : RNN | |||
| { | |||
| SimpleRNNArgs args; | |||
| public SimpleRNN(SimpleRNNArgs args) : base(args) | |||
| public SimpleRNN(SimpleRNNArgs args) : base(CreateCellForArgs(args)) | |||
| { | |||
| this.args = args; | |||
| } | |||
| /// <summary> | |||
| /// Creates the <see cref="SimpleRNNCell"/> described by <paramref name="args"/> and stores it in <c>args.Cell</c>. | |||
| /// </summary> | |||
| /// <param name="args">Layer arguments; mutated in place (its <c>Cell</c> is overwritten).</param> | |||
| /// <returns>The same <paramref name="args"/> instance, ready to pass to the base constructor.</returns> | |||
| private static SimpleRNNArgs CreateCellForArgs(SimpleRNNArgs args) | |||
| { | |||
| // Copy the cell-relevant settings of the layer into dedicated cell arguments. | |||
| var cell_args = new SimpleRNNCellArgs() | |||
| { | |||
| Units = args.Units, | |||
| Activation = args.Activation, | |||
| UseBias = args.UseBias, | |||
| KernelInitializer = args.KernelInitializer, | |||
| RecurrentInitializer = args.RecurrentInitializer, | |||
| BiasInitializer = args.BiasInitializer, | |||
| Dropout = args.Dropout, | |||
| RecurrentDropout = args.RecurrentDropout, | |||
| DType = args.DType, | |||
| Trainable = args.Trainable, | |||
| }; | |||
| args.Cell = new SimpleRNNCell(cell_args); | |||
| return args; | |||
| } | |||
| public override void build(KerasShapesWrapper input_shape) | |||
| { | |||
| var single_shape = input_shape.ToSingleShape(); | |||
| var input_dim = single_shape[-1]; | |||
| _buildInputShape = input_shape; | |||
| kernel = add_weight("kernel", (single_shape[-1], args.Units), | |||
| _kernel = add_weight("kernel", (single_shape[-1], args.Units), | |||
| initializer: args.KernelInitializer | |||
| //regularizer = self.kernel_regularizer, | |||
| //constraint = self.kernel_constraint, | |||
| @@ -4,47 +4,114 @@ using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Common.Extensions; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| public class SimpleRNNCell : Layer | |||
| /// <summary> | |||
| /// Cell class for SimpleRNN. | |||
| /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) | |||
| /// for details about the usage of RNN API. | |||
| /// This class processes one step within the whole time sequence input, whereas | |||
| /// `tf.keras.layer.SimpleRNN` processes the whole sequence. | |||
| /// </summary> | |||
| public class SimpleRNNCell : DropoutRNNCellMixin | |||
| { | |||
| SimpleRNNArgs args; | |||
| IVariableV1 kernel; | |||
| IVariableV1 recurrent_kernel; | |||
| IVariableV1 bias; | |||
| SimpleRNNCellArgs _args; | |||
| IVariableV1 _kernel; | |||
| IVariableV1 _recurrent_kernel; | |||
| IVariableV1 _bias; | |||
| GeneralizedTensorShape _state_size; | |||
| GeneralizedTensorShape _output_size; | |||
| public SimpleRNNCell(SimpleRNNArgs args) : base(args) | |||
| public override GeneralizedTensorShape StateSize => _state_size; | |||
| public override GeneralizedTensorShape OutputSize => _output_size; | |||
| public override bool IsTFRnnCell => true; | |||
| public override bool SupportOptionalArgs => false; | |||
| /// <summary> | |||
| /// Stores the cell arguments, validates the unit count, clamps both dropout rates | |||
| /// and derives the state and output sizes from <c>args.Units</c>. | |||
| /// </summary> | |||
| /// <param name="args">Cell arguments; <c>Dropout</c> and <c>RecurrentDropout</c> are overwritten with clamped values.</param> | |||
| /// <exception cref="ValueError">Thrown when <c>args.Units</c> is zero or negative.</exception> | |||
| public SimpleRNNCell(SimpleRNNCellArgs args) : base(args) | |||
| { | |||
| // NOTE(review): the next line assigns the old `args` field and looks like the pre-change line | |||
| // shown next to its replacement in this diff view -- confirm against the actual file. | |||
| this.args = args; | |||
| this._args = args; | |||
| if (args.Units <= 0) | |||
| { | |||
| throw new ValueError( | |||
| $"units must be a positive integer, got {args.Units}"); | |||
| } | |||
| // Clamp both dropout rates into [0, 1]. | |||
| this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout)); | |||
| this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout)); | |||
| // State and output sizes are both built from the unit count. | |||
| _state_size = new GeneralizedTensorShape(args.Units); | |||
| _output_size = new GeneralizedTensorShape(args.Units); | |||
| } | |||
| public override void build(KerasShapesWrapper input_shape) | |||
| { | |||
| // TODO(Rinne): add the cache. | |||
| var single_shape = input_shape.ToSingleShape(); | |||
| var input_dim = single_shape[-1]; | |||
| kernel = add_weight("kernel", (single_shape[-1], args.Units), | |||
| initializer: args.KernelInitializer | |||
| _kernel = add_weight("kernel", (single_shape[-1], _args.Units), | |||
| initializer: _args.KernelInitializer | |||
| ); | |||
| recurrent_kernel = add_weight("recurrent_kernel", (args.Units, args.Units), | |||
| initializer: args.RecurrentInitializer | |||
| _recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units), | |||
| initializer: _args.RecurrentInitializer | |||
| ); | |||
| if (args.UseBias) | |||
| if (_args.UseBias) | |||
| { | |||
| bias = add_weight("bias", (args.Units), | |||
| initializer: args.BiasInitializer | |||
| _bias = add_weight("bias", (_args.Units), | |||
| initializer: _args.BiasInitializer | |||
| ); | |||
| } | |||
| built = true; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| // TODO(Rinne): revise the training param (with refactoring of the framework) | |||
| /// <summary> | |||
| /// Runs one step of the cell: multiplies the inputs (optionally masked by the dropout mask) with the kernel, | |||
| /// adds the bias when one exists, adds the previous output (optionally masked by the recurrent dropout mask) | |||
| /// multiplied with the recurrent kernel, and applies the activation when one is configured. | |||
| /// </summary> | |||
| /// <param name="inputs">Input tensors for this step.</param> | |||
| /// <param name="states">Previous state; when nested, only <c>states[0]</c> is used as the previous output.</param> | |||
| /// <param name="training">Passed to the dropout-mask helpers through <c>training.Value</c>.</param> | |||
| /// <param name="optional_args">Not read by this method.</param> | |||
| /// <returns> | |||
| /// The step result twice: as a nested structure converted with <c>ToTensors()</c> when <paramref name="states"/> | |||
| /// is nested, otherwise as <c>new Tensors(output, output)</c>. | |||
| /// </returns> | |||
| protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| // NOTE(review): this return references `state` and looks like the pre-change body shown next to the | |||
| // new one in this diff view; if it is live, nothing below it executes -- confirm against the actual file. | |||
| return base.Call(inputs, state, training); | |||
| // TODO(Rinne): check if it will have multiple tensors when not nested. | |||
| Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states; | |||
| // NOTE(review): `training` defaults to null and `.Value` on a null `bool?` throws | |||
| // InvalidOperationException -- confirm every caller passes an explicit value. | |||
| var dp_mask = get_dropout_maskcell_for_cell(inputs, training.Value); | |||
| var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value); | |||
| Tensor h; | |||
| if (dp_mask != null) | |||
| { | |||
| // Apply the input dropout mask before the kernel projection. | |||
| h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor()); | |||
| } | |||
| else | |||
| { | |||
| h = math_ops.matmul(inputs, _kernel.AsTensor()); | |||
| } | |||
| // _bias is only created in build() when UseBias is set. | |||
| if (_bias != null) | |||
| { | |||
| h = tf.nn.bias_add(h, _bias); | |||
| } | |||
| if (rec_dp_mask != null) | |||
| { | |||
| // Apply the recurrent dropout mask to the previous output. | |||
| prev_output = math_ops.multiply(prev_output, rec_dp_mask); | |||
| } | |||
| Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor()); | |||
| if (_args.Activation != null) | |||
| { | |||
| output = _args.Activation.Apply(output); | |||
| } | |||
| // Mirror the nesting of the incoming states in the returned structure. | |||
| if (Nest.IsNested(states)) | |||
| { | |||
| return new Nest<Tensor>(new List<Nest<Tensor>> { | |||
| new Nest<Tensor>(new List<Nest<Tensor>> { new Nest<Tensor>(output) }), new Nest<Tensor>(output) }) | |||
| .ToTensors(); | |||
| } | |||
| else | |||
| { | |||
| return new Tensors(output, output); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.ComponentModel; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
| using Tensorflow.Keras.Engine; | |||
| @@ -8,7 +9,7 @@ using Tensorflow.Keras.Saving; | |||
| namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell | |||
| public class StackedRNNCells : Layer, IRnnCell | |||
| { | |||
| public IList<RnnCell> Cells { get; set; } | |||
| public bool reverse_state_order; | |||
| @@ -51,7 +52,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| { | |||
| return lastCell.output_size; | |||
| } | |||
| else if (RNN._is_multiple_state(lastCell.state_size)) | |||
| else if (RNN.is_multiple_state(lastCell.StateSize)) | |||
| { | |||
| // return ((dynamic)Cells[-1].state_size)[0]; | |||
| throw new NotImplementedException(""); | |||
| @@ -162,5 +163,14 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
| // deserialize_layer(cell_config, custom_objects = custom_objects)) | |||
| // return cls(cells, **config) | |||
| } | |||
| /// <summary> | |||
| /// Placeholder step function for the stacked cells; not implemented yet. | |||
| /// </summary> | |||
| /// <exception cref="NotImplementedException">Always thrown.</exception> | |||
| public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| // The four members below are stubs that always throw; presumably they exist to satisfy IRnnCell -- confirm. | |||
| public GeneralizedTensorShape StateSize => throw new NotImplementedException(); | |||
| public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); | |||
| public bool IsTFRnnCell => throw new NotImplementedException(); | |||
| public bool SupportOptionalArgs => throw new NotImplementedException(); | |||
| } | |||
| } | |||
| @@ -10,6 +10,7 @@ using Tensorflow.Keras.Engine; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Functions; | |||
| using System.Threading; | |||
| using Tensorflow.Common.Types; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -34,7 +35,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| return DeFunCall(inputs); | |||
| @@ -304,7 +304,7 @@ public class metrics_utils | |||
| var NEG_INF = -1e10; | |||
| var (_, top_k_idx) = tf.math.top_k(x, k, sorted: false); | |||
| var top_k_mask = tf.reduce_sum( | |||
| tf.one_hot(top_k_idx, (int)x.shape[-1], axis: -1), axis: -2); | |||
| tf.one_hot(top_k_idx.Single, (int)x.shape[-1], axis: -1), axis: -2); | |||
| return x * top_k_mask + NEG_INF * (1 - top_k_mask); | |||
| } | |||
| } | |||
| @@ -129,7 +129,7 @@ namespace Tensorflow.Keras | |||
| var indices = z.map(m => | |||
| { | |||
| var (i, positions) = m; | |||
| return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor); | |||
| return tf.range(positions.Single[i], positions.Single[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor); | |||
| }, num_parallel_calls: -1); | |||
| var dataset = sequences_from_indices(data, indices, start_index, end_index); | |||
| @@ -8,7 +8,7 @@ using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Reflection; | |||
| using System.Text.RegularExpressions; | |||
| using Tensorflow.Extensions; | |||
| using Tensorflow.Common.Extensions; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| @@ -0,0 +1,93 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.Layers.Rnn; | |||
| using Tensorflow.Common.Extensions; | |||
| namespace Tensorflow.Keras.Utils | |||
| { | |||
| internal static class RnnUtils | |||
| { | |||
| /// <summary> | |||
| /// Creates zero tensors to serve as an initial state. | |||
| /// </summary> | |||
| /// <param name="batch_size_tensor">Leading dimension prepended to every state shape.</param> | |||
| /// <param name="state_size">State shape description; one tensor is produced per shape when it holds more than one.</param> | |||
| /// <param name="dtype">Element type of the created tensors.</param> | |||
| /// <returns>A single zero tensor, or one zero tensor per shape in <paramref name="state_size"/>.</returns> | |||
| internal static Tensors generate_zero_filled_state(long batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) | |||
| { | |||
| // Zero tensor whose shape is the batch dimension followed by the dims of one un-nested state size. | |||
| Tensor ZerosFor(GeneralizedTensorShape single_state_size) | |||
| { | |||
| var trailing_dims = single_state_size.ToSingleShape().dims; | |||
| var batch_dim = new long[] { batch_size_tensor }; | |||
| var full_dims = batch_dim.Concat(trailing_dims).ToArray(); | |||
| return array_ops.zeros(new Shape(full_dims), dtype: dtype); | |||
| } | |||
| // TODO(Rinne): map structure with nested tensors. | |||
| if (state_size.Shapes.Length <= 1) | |||
| { | |||
| return ZerosFor(state_size); | |||
| } | |||
| var per_shape_zeros = state_size.ToShapeArray().Select(shape => ZerosFor(new GeneralizedTensorShape(shape))); | |||
| return new Tensors(per_shape_zeros); | |||
| } | |||
| /// <summary> | |||
| /// Creates a zero-filled initial state sized by <c>cell.StateSize</c>. | |||
| /// </summary> | |||
| /// <param name="cell">Cell whose <c>StateSize</c> determines the state shape(s).</param> | |||
| /// <param name="inputs">Optional inputs; when non-null, its first dimension and dtype replace the two arguments below.</param> | |||
| /// <param name="batch_size">Batch size used when <paramref name="inputs"/> is null.</param> | |||
| /// <param name="dtype">Element type used when <paramref name="inputs"/> is null.</param> | |||
| /// <returns>The zero-filled state from <c>generate_zero_filled_state</c>.</returns> | |||
| internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype) | |||
| { | |||
| if (inputs != null) | |||
| { | |||
| // Inputs take precedence over the explicitly supplied batch size and dtype. | |||
| var input_batch = inputs.shape[0]; | |||
| var input_dtype = inputs.dtype; | |||
| return generate_zero_filled_state(input_batch, cell.StateSize, input_dtype); | |||
| } | |||
| return generate_zero_filled_state(batch_size, cell.StateSize, dtype); | |||
| } | |||
| /// <summary> | |||
| /// Standardizes `__call__` to a single list of tensor inputs. | |||
| /// | |||
| /// When running a model loaded from a file, the input tensors | |||
| /// `initial_state` and `constants` can be passed to `RNN.__call__()` as part | |||
| /// of `inputs` instead of by the dedicated keyword arguments. This method | |||
| /// makes sure the arguments are separated and that `initial_state` and | |||
| /// `constants` are lists of tensors (or null). | |||
| /// </summary> | |||
| /// <param name="inputs">Tensor or list of tensors, which may include constants | |||
| /// and initial states. In that case `num_constants` must be specified.</param> | |||
| /// <param name="initial_state">Tensor or list of tensors or null, initial states.</param> | |||
| /// <param name="constants">Tensor or list of tensors or null, constant tensors.</param> | |||
| /// <param name="num_constants">Expected number of constants (if constants are passed as | |||
| /// part of the `inputs` list).</param> | |||
| /// <returns>The separated <c>(inputs, initial_state, constants)</c> triple.</returns> | |||
| internal static (Tensors, Tensors, Tensors) standardize_args(Tensors inputs, Tensors initial_state, Tensors constants, int num_constants) | |||
| { | |||
| // A single input carries nothing to split off. | |||
| if (inputs.Length <= 1) | |||
| { | |||
| return (inputs, initial_state, constants); | |||
| } | |||
| // Several situations lead here: | |||
| // In graph mode, __call__ is only called once; initial_state and constants | |||
| // may be inside inputs (from file loading). | |||
| // In eager mode, __call__ is called twice: once during | |||
| // rnn_layer(inputs=input_t, constants=c_t, ...), and again from | |||
| // model.fit/train_on_batch/predict with real np data. In the second case | |||
| // inputs contains initial_state and constants as eager tensors. | |||
| // | |||
| // In either case the real input is the first item of the list (possibly a | |||
| // nested structure itself), followed by the initial states (a list of items, | |||
| // or a list of lists for a complex state structure), and finally by the | |||
| // constants, which form a flat list. | |||
| Debug.Assert(initial_state is null && constants is null); | |||
| if (num_constants > 0) | |||
| { | |||
| // Constants sit at the tail of the list. | |||
| constants = inputs.TakeLast(num_constants).ToTensors(); | |||
| inputs = inputs.SkipLast(num_constants).ToTensors(); | |||
| } | |||
| if (inputs.Length > 1) | |||
| { | |||
| // Everything after the first item is initial state. | |||
| initial_state = inputs.Skip(1).ToTensors(); | |||
| inputs = inputs.Take(1).ToTensors(); | |||
| } | |||
| return (inputs, initial_state, constants); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Common.Types; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| @@ -89,7 +90,7 @@ namespace Tensorflow.Hub | |||
| } | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optionalArgs = null) | |||
| { | |||
| _check_trainability(); | |||
| @@ -144,17 +144,6 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||
| Assert.AreEqual(expected_output, actual_output); | |||
| } | |||
| [TestMethod, Ignore("WIP")] | |||
| public void SimpleRNN() | |||
| { | |||
| var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32); | |||
| /*var simple_rnn = keras.layers.SimpleRNN(4); | |||
| var output = simple_rnn.Apply(inputs); | |||
| Assert.AreEqual((32, 4), output.shape);*/ | |||
| var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true); | |||
| var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs); | |||
| } | |||
| [TestMethod] | |||
| public void Resizing() | |||
| { | |||
| @@ -0,0 +1,28 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using Tensorflow.NumPy; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.UnitTest.Layers | |||
| { | |||
| /// <summary> | |||
| /// Smoke tests for the recurrent Keras layers. | |||
| /// </summary> | |||
| [TestClass] | |||
| public class Rnn | |||
| { | |||
| /// <summary> | |||
| /// Applies a SimpleRNN with 4 units to a (6, 10, 8) float input and prints both results. | |||
| /// No assertions are made yet. | |||
| /// </summary> | |||
| [TestMethod] | |||
| public void SimpleRNN() | |||
| { | |||
| var flat_values = np.arange(6 * 10 * 8); | |||
| var inputs = flat_values.reshape((6, 10, 8)).astype(np.float32); | |||
| // Disabled earlier draft, kept for reference: | |||
| // var simple_rnn = keras.layers.SimpleRNN(4); | |||
| // var output = simple_rnn.Apply(inputs); | |||
| // Assert.AreEqual((32, 4), output.shape); | |||
| var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true); | |||
| var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs); | |||
| Console.WriteLine(whole_sequence_output); | |||
| Console.WriteLine(final_state); | |||
| } | |||
| } | |||
| } | |||
| @@ -20,7 +20,7 @@ namespace Tensorflow | |||
| // whole_sequence_output has shape `[32, 10, 4]`. | |||
| // final_state has shape `[32, 4]`. | |||
| var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs); | |||
| var (whole_sequence_output, final_states) = simple_rnn.Apply(inputs); | |||
| } | |||
| } | |||
| } | |||