diff --git a/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs index 5ac0cff5..ec49f850 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs @@ -31,8 +31,6 @@ namespace Tensorflow => op._handle; public static implicit operator Tensor(Operation op) => op.output; - public static implicit operator RefVariable(Operation op) - => new RefVariable(op); public override string ToString() { diff --git a/src/TensorFlowNET.Core/Operations/embedding_ops.cs b/src/TensorFlowNET.Core/Operations/embedding_ops.cs index fe79e006..2e349ed3 100644 --- a/src/TensorFlowNET.Core/Operations/embedding_ops.cs +++ b/src/TensorFlowNET.Core/Operations/embedding_ops.cs @@ -21,37 +21,6 @@ namespace Tensorflow { public class embedding_ops { - /// - /// Helper function for embedding_lookup and _compute_sampled_logits. - /// - /// - /// - /// - /// - /// - public static Tensor _embedding_lookup_and_transform(RefVariable @params, - Tensor ids, - string partition_strategy = "mod", - string name = null, - string max_norm = null) - { - return tf_with(ops.name_scope(name, "embedding_lookup", new { @params, ids }), scope => - { - name = scope; - int np = 1; - ids = ops.convert_to_tensor(ids, name: "ids"); - if (np == 1) - { - var gather = array_ops.gather(@params, ids, name: name); - var result = _clip(gather, ids, max_norm); - - return array_ops.identity(result); - } - - throw new NotImplementedException("_embedding_lookup_and_transform"); - }); - } - /// /// Helper function for embedding_lookup and _compute_sampled_logits. /// diff --git a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs index dd91a5de..6f1b7790 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs @@ -10,9 +10,16 @@ namespace Tensorflow.Keras.Engine LossesContainer compiled_loss; MetricsContainer compiled_metrics; - public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) + public void compile(OptimizerV2 optimizer = null, + ILossFunc loss = null, + string[] metrics = null) { - this.optimizer = optimizer; + this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs + { + }); + + this.loss = loss ?? new MeanSquaredError(); + compiled_loss = new LossesContainer(loss, output_names: output_names); compiled_metrics = new MetricsContainer(metrics, output_names: output_names); @@ -22,7 +29,6 @@ namespace Tensorflow.Keras.Engine // Initialize cache attrs. _reset_compile_cache(); _is_compiled = true; - this.loss = loss; } public void compile(string optimizer, string loss, string[] metrics) diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index b7cefb66..3ebe526c 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -518,6 +518,18 @@ namespace Tensorflow.Keras.Layers return input_layer.InboundNodes[0].Outputs; } + public InputLayer InputLayer(TensorShape input_shape, + string name = null, + bool sparse = false, + bool ragged = false) + => new InputLayer(new InputLayerArgs + { + InputShape = input_shape, + Name = name, + Sparse = sparse, + Ragged = ragged + }); + /// /// Max pooling operation for 1D temporal data. /// diff --git a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj index 6567a1ae..419dc308 100644 --- a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj +++ b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj @@ -36,7 +36,7 @@ - + diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index 1f4814b9..d9051031 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -1,7 +1,8 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using NumSharp; +using System.Collections.Generic; using Tensorflow; -using Tensorflow.Operations.Initializers; +using Tensorflow.Keras; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -13,6 +14,18 @@ namespace TensorFlowNET.Keras.UnitTest [TestClass] public class LayersTest : EagerModeTestBase { + // [TestMethod] + public void InputLayer() + { + var model = keras.Sequential(new List + { + keras.layers.InputLayer(input_shape: 4), + keras.layers.Dense(8) + }); + model.compile(optimizer: keras.optimizers.RMSprop(0.001f)); + model.fit(np.zeros((10, 4)), np.ones((10, 8))); + } + [TestMethod] public void Sequential() { diff --git a/test/TensorFlowNET.Keras.UnitTest/Range.cs b/test/TensorFlowNET.Keras.UnitTest/Range.cs deleted file mode 100644 index be3460f5..00000000 --- a/test/TensorFlowNET.Keras.UnitTest/Range.cs +++ /dev/null @@ -1,234 +0,0 @@ -// https://github.com/dotnet/corefx/blob/1597b894a2e9cac668ce6e484506eca778a85197/src/Common/src/CoreLib/System/Index.cs -// https://github.com/dotnet/corefx/blob/1597b894a2e9cac668ce6e484506eca778a85197/src/Common/src/CoreLib/System/Range.cs - -using System.Runtime.CompilerServices; - -namespace System -{ - /// Represent a type can be used to index a collection either from the start or the end. - /// - /// Index is used by the C# compiler to support the new index syntax - /// - /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 } ; - /// int lastElement = someArray[^1]; // lastElement = 5 - /// - /// - public readonly struct Index : IEquatable - { - private readonly int _value; - - /// Construct an Index using a value and indicating if the index is from the start or from the end. - /// The index value. it has to be zero or positive number. - /// Indicating if the index is from the start or from the end. - /// - /// If the Index constructed from the end, index value 1 means pointing at the last element and index value 0 means pointing at beyond last element. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public Index(int value, bool fromEnd = false) - { - if (value < 0) - { - throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative"); - } - - if (fromEnd) - _value = ~value; - else - _value = value; - } - - // The following private constructors mainly created for perf reason to avoid the checks - private Index(int value) - { - _value = value; - } - - /// Create an Index pointing at first element. - public static Index Start => new Index(0); - - /// Create an Index pointing at beyond last element. - public static Index End => new Index(~0); - - /// Create an Index from the start at the position indicated by the value. - /// The index value from the start. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Index FromStart(int value) - { - if (value < 0) - { - throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative"); - } - - return new Index(value); - } - - /// Create an Index from the end at the position indicated by the value. - /// The index value from the end. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Index FromEnd(int value) - { - if (value < 0) - { - throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative"); - } - - return new Index(~value); - } - - /// Returns the index value. - public int Value - { - get - { - if (_value < 0) - { - return ~_value; - } - else - { - return _value; - } - } - } - - /// Indicates whether the index is from the start or the end. - public bool IsFromEnd => _value < 0; - - /// Calculate the offset from the start using the giving collection length. - /// The length of the collection that the Index will be used with. length has to be a positive value - /// - /// For performance reason, we don't validate the input length parameter and the returned offset value against negative values. - /// we don't validate either the returned offset is greater than the input length. - /// It is expected Index will be used with collections which always have non negative length/count. If the returned offset is negative and - /// then used to index a collection will get out of range exception which will be same affect as the validation. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public int GetOffset(int length) - { - var offset = _value; - if (IsFromEnd) - { - // offset = length - (~value) - // offset = length + (~(~value) + 1) - // offset = length + value + 1 - - offset += length + 1; - } - return offset; - } - - /// Indicates whether the current Index object is equal to another object of the same type. - /// An object to compare with this object - public override bool Equals(object? value) => value is Index && _value == ((Index)value)._value; - - /// Indicates whether the current Index object is equal to another Index object. - /// An object to compare with this object - public bool Equals(Index other) => _value == other._value; - - /// Returns the hash code for this instance. - public override int GetHashCode() => _value; - - /// Converts integer number to an Index. - public static implicit operator Index(int value) => FromStart(value); - - /// Converts the value of the current Index object to its equivalent string representation. - public override string ToString() - { - if (IsFromEnd) - return "^" + ((uint)Value).ToString(); - - return ((uint)Value).ToString(); - } - } - - /// Represent a range has start and end indexes. - /// - /// Range is used by the C# compiler to support the range syntax. - /// - /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 }; - /// int[] subArray1 = someArray[0..2]; // { 1, 2 } - /// int[] subArray2 = someArray[1..^0]; // { 2, 3, 4, 5 } - /// - /// - public readonly struct Range : IEquatable - { - /// Represent the inclusive start index of the Range. - public Index Start { get; } - - /// Represent the exclusive end index of the Range. - public Index End { get; } - - /// Construct a Range object using the start and end indexes. - /// Represent the inclusive start index of the range. - /// Represent the exclusive end index of the range. - public Range(Index start, Index end) - { - Start = start; - End = end; - } - - /// Indicates whether the current Range object is equal to another object of the same type. - /// An object to compare with this object - public override bool Equals(object? value) => - value is Range r && - r.Start.Equals(Start) && - r.End.Equals(End); - - /// Indicates whether the current Range object is equal to another Range object. - /// An object to compare with this object - public bool Equals(Range other) => other.Start.Equals(Start) && other.End.Equals(End); - - /// Returns the hash code for this instance. - public override int GetHashCode() - { - return Start.GetHashCode() * 31 + End.GetHashCode(); - } - - /// Converts the value of the current Range object to its equivalent string representation. - public override string ToString() - { - return Start + ".." + End; - } - - /// Create a Range object starting from start index to the end of the collection. - public static Range StartAt(Index start) => new Range(start, Index.End); - - /// Create a Range object starting from first element in the collection to the end Index. - public static Range EndAt(Index end) => new Range(Index.Start, end); - - /// Create a Range object starting from first element to the end. - public static Range All => new Range(Index.Start, Index.End); - - /// Calculate the start offset and length of range object using a collection length. - /// The length of the collection that the range will be used with. length has to be a positive value. - /// - /// For performance reason, we don't validate the input length parameter against negative values. - /// It is expected Range will be used with collections which always have non negative length/count. - /// We validate the range is inside the length scope though. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public (int Offset, int Length) GetOffsetAndLength(int length) - { - int start; - var startIndex = Start; - if (startIndex.IsFromEnd) - start = length - startIndex.Value; - else - start = startIndex.Value; - - int end; - var endIndex = End; - if (endIndex.IsFromEnd) - end = length - endIndex.Value; - else - end = endIndex.Value; - - if ((uint)end > (uint)length || (uint)start > (uint)end) - { - throw new ArgumentOutOfRangeException(nameof(length)); - } - - return (start, end - start); - } - } -} diff --git a/test/TensorFlowNET.Keras.UnitTest/RuntimeHelpers.cs b/test/TensorFlowNET.Keras.UnitTest/RuntimeHelpers.cs deleted file mode 100644 index 22f158c4..00000000 --- a/test/TensorFlowNET.Keras.UnitTest/RuntimeHelpers.cs +++ /dev/null @@ -1,39 +0,0 @@ -namespace System.Runtime.CompilerServices -{ - internal static class RuntimeHelpers - { - /// - /// Slices the specified array using the specified range. - /// - public static T[] GetSubArray(T[] array, Range range) - { - if (array == null) - { - throw new ArgumentNullException(nameof(array)); - } - - (int offset, int length) = range.GetOffsetAndLength(array.Length); - - if (default(T) != null || typeof(T[]) == array.GetType()) - { - // We know the type of the array to be exactly T[]. - - if (length == 0) - { - return Array.Empty(); - } - - var dest = new T[length]; - Array.Copy(array, offset, dest, 0, length); - return dest; - } - else - { - // The array is actually a U[] where U:T. - var dest = (T[])Array.CreateInstance(array.GetType().GetElementType(), length); - Array.Copy(array, offset, dest, 0, length); - return dest; - } - } - } -} \ No newline at end of file