diff --git a/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs
index 5ac0cff5..ec49f850 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs
@@ -31,8 +31,6 @@ namespace Tensorflow
=> op._handle;
public static implicit operator Tensor(Operation op)
=> op.output;
- public static implicit operator RefVariable(Operation op)
- => new RefVariable(op);
public override string ToString()
{
diff --git a/src/TensorFlowNET.Core/Operations/embedding_ops.cs b/src/TensorFlowNET.Core/Operations/embedding_ops.cs
index fe79e006..2e349ed3 100644
--- a/src/TensorFlowNET.Core/Operations/embedding_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/embedding_ops.cs
@@ -21,37 +21,6 @@ namespace Tensorflow
{
public class embedding_ops
{
- ///
- /// Helper function for embedding_lookup and _compute_sampled_logits.
- ///
- ///
- ///
- ///
- ///
- ///
- public static Tensor _embedding_lookup_and_transform(RefVariable @params,
- Tensor ids,
- string partition_strategy = "mod",
- string name = null,
- string max_norm = null)
- {
- return tf_with(ops.name_scope(name, "embedding_lookup", new { @params, ids }), scope =>
- {
- name = scope;
- int np = 1;
- ids = ops.convert_to_tensor(ids, name: "ids");
- if (np == 1)
- {
- var gather = array_ops.gather(@params, ids, name: name);
- var result = _clip(gather, ids, max_norm);
-
- return array_ops.identity(result);
- }
-
- throw new NotImplementedException("_embedding_lookup_and_transform");
- });
- }
-
///
/// Helper function for embedding_lookup and _compute_sampled_logits.
///
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs
index dd91a5de..6f1b7790 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs
@@ -10,9 +10,16 @@ namespace Tensorflow.Keras.Engine
LossesContainer compiled_loss;
MetricsContainer compiled_metrics;
- public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics)
+ public void compile(OptimizerV2 optimizer = null,
+ ILossFunc loss = null,
+ string[] metrics = null)
{
- this.optimizer = optimizer;
+ this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs
+ {
+ });
+
+ this.loss = loss ?? new MeanSquaredError();
+
compiled_loss = new LossesContainer(loss, output_names: output_names);
compiled_metrics = new MetricsContainer(metrics, output_names: output_names);
@@ -22,7 +29,6 @@ namespace Tensorflow.Keras.Engine
// Initialize cache attrs.
_reset_compile_cache();
_is_compiled = true;
- this.loss = loss;
}
public void compile(string optimizer, string loss, string[] metrics)
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index b7cefb66..3ebe526c 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -518,6 +518,18 @@ namespace Tensorflow.Keras.Layers
return input_layer.InboundNodes[0].Outputs;
}
+ public InputLayer InputLayer(TensorShape input_shape,
+ string name = null,
+ bool sparse = false,
+ bool ragged = false)
+ => new InputLayer(new InputLayerArgs
+ {
+ InputShape = input_shape,
+ Name = name,
+ Sparse = sparse,
+ Ragged = ragged
+ });
+
///
/// Max pooling operation for 1D temporal data.
///
diff --git a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
index 6567a1ae..419dc308 100644
--- a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
+++ b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
@@ -36,7 +36,7 @@
-
+
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
index 1f4814b9..d9051031 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
@@ -1,7 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
+using System.Collections.Generic;
using Tensorflow;
-using Tensorflow.Operations.Initializers;
+using Tensorflow.Keras;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
@@ -13,6 +14,18 @@ namespace TensorFlowNET.Keras.UnitTest
[TestClass]
public class LayersTest : EagerModeTestBase
{
+ // [TestMethod]
+ public void InputLayer()
+ {
+ var model = keras.Sequential(new List
+ {
+ keras.layers.InputLayer(input_shape: 4),
+ keras.layers.Dense(8)
+ });
+ model.compile(optimizer: keras.optimizers.RMSprop(0.001f));
+ model.fit(np.zeros((10, 4)), np.ones((10, 8)));
+ }
+
[TestMethod]
public void Sequential()
{
diff --git a/test/TensorFlowNET.Keras.UnitTest/Range.cs b/test/TensorFlowNET.Keras.UnitTest/Range.cs
deleted file mode 100644
index be3460f5..00000000
--- a/test/TensorFlowNET.Keras.UnitTest/Range.cs
+++ /dev/null
@@ -1,234 +0,0 @@
-// https://github.com/dotnet/corefx/blob/1597b894a2e9cac668ce6e484506eca778a85197/src/Common/src/CoreLib/System/Index.cs
-// https://github.com/dotnet/corefx/blob/1597b894a2e9cac668ce6e484506eca778a85197/src/Common/src/CoreLib/System/Range.cs
-
-using System.Runtime.CompilerServices;
-
-namespace System
-{
- /// Represent a type can be used to index a collection either from the start or the end.
- ///
- /// Index is used by the C# compiler to support the new index syntax
- ///
- /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 } ;
- /// int lastElement = someArray[^1]; // lastElement = 5
- ///
- ///
- public readonly struct Index : IEquatable
- {
- private readonly int _value;
-
- /// Construct an Index using a value and indicating if the index is from the start or from the end.
- /// The index value. it has to be zero or positive number.
- /// Indicating if the index is from the start or from the end.
- ///
- /// If the Index constructed from the end, index value 1 means pointing at the last element and index value 0 means pointing at beyond last element.
- ///
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public Index(int value, bool fromEnd = false)
- {
- if (value < 0)
- {
- throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative");
- }
-
- if (fromEnd)
- _value = ~value;
- else
- _value = value;
- }
-
- // The following private constructors mainly created for perf reason to avoid the checks
- private Index(int value)
- {
- _value = value;
- }
-
- /// Create an Index pointing at first element.
- public static Index Start => new Index(0);
-
- /// Create an Index pointing at beyond last element.
- public static Index End => new Index(~0);
-
- /// Create an Index from the start at the position indicated by the value.
- /// The index value from the start.
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static Index FromStart(int value)
- {
- if (value < 0)
- {
- throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative");
- }
-
- return new Index(value);
- }
-
- /// Create an Index from the end at the position indicated by the value.
- /// The index value from the end.
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static Index FromEnd(int value)
- {
- if (value < 0)
- {
- throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative");
- }
-
- return new Index(~value);
- }
-
- /// Returns the index value.
- public int Value
- {
- get
- {
- if (_value < 0)
- {
- return ~_value;
- }
- else
- {
- return _value;
- }
- }
- }
-
- /// Indicates whether the index is from the start or the end.
- public bool IsFromEnd => _value < 0;
-
- /// Calculate the offset from the start using the giving collection length.
- /// The length of the collection that the Index will be used with. length has to be a positive value
- ///
- /// For performance reason, we don't validate the input length parameter and the returned offset value against negative values.
- /// we don't validate either the returned offset is greater than the input length.
- /// It is expected Index will be used with collections which always have non negative length/count. If the returned offset is negative and
- /// then used to index a collection will get out of range exception which will be same affect as the validation.
- ///
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public int GetOffset(int length)
- {
- var offset = _value;
- if (IsFromEnd)
- {
- // offset = length - (~value)
- // offset = length + (~(~value) + 1)
- // offset = length + value + 1
-
- offset += length + 1;
- }
- return offset;
- }
-
- /// Indicates whether the current Index object is equal to another object of the same type.
- /// An object to compare with this object
- public override bool Equals(object? value) => value is Index && _value == ((Index)value)._value;
-
- /// Indicates whether the current Index object is equal to another Index object.
- /// An object to compare with this object
- public bool Equals(Index other) => _value == other._value;
-
- /// Returns the hash code for this instance.
- public override int GetHashCode() => _value;
-
- /// Converts integer number to an Index.
- public static implicit operator Index(int value) => FromStart(value);
-
- /// Converts the value of the current Index object to its equivalent string representation.
- public override string ToString()
- {
- if (IsFromEnd)
- return "^" + ((uint)Value).ToString();
-
- return ((uint)Value).ToString();
- }
- }
-
- /// Represent a range has start and end indexes.
- ///
- /// Range is used by the C# compiler to support the range syntax.
- ///
- /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 };
- /// int[] subArray1 = someArray[0..2]; // { 1, 2 }
- /// int[] subArray2 = someArray[1..^0]; // { 2, 3, 4, 5 }
- ///
- ///
- public readonly struct Range : IEquatable
- {
- /// Represent the inclusive start index of the Range.
- public Index Start { get; }
-
- /// Represent the exclusive end index of the Range.
- public Index End { get; }
-
- /// Construct a Range object using the start and end indexes.
- /// Represent the inclusive start index of the range.
- /// Represent the exclusive end index of the range.
- public Range(Index start, Index end)
- {
- Start = start;
- End = end;
- }
-
- /// Indicates whether the current Range object is equal to another object of the same type.
- /// An object to compare with this object
- public override bool Equals(object? value) =>
- value is Range r &&
- r.Start.Equals(Start) &&
- r.End.Equals(End);
-
- /// Indicates whether the current Range object is equal to another Range object.
- /// An object to compare with this object
- public bool Equals(Range other) => other.Start.Equals(Start) && other.End.Equals(End);
-
- /// Returns the hash code for this instance.
- public override int GetHashCode()
- {
- return Start.GetHashCode() * 31 + End.GetHashCode();
- }
-
- /// Converts the value of the current Range object to its equivalent string representation.
- public override string ToString()
- {
- return Start + ".." + End;
- }
-
- /// Create a Range object starting from start index to the end of the collection.
- public static Range StartAt(Index start) => new Range(start, Index.End);
-
- /// Create a Range object starting from first element in the collection to the end Index.
- public static Range EndAt(Index end) => new Range(Index.Start, end);
-
- /// Create a Range object starting from first element to the end.
- public static Range All => new Range(Index.Start, Index.End);
-
- /// Calculate the start offset and length of range object using a collection length.
- /// The length of the collection that the range will be used with. length has to be a positive value.
- ///
- /// For performance reason, we don't validate the input length parameter against negative values.
- /// It is expected Range will be used with collections which always have non negative length/count.
- /// We validate the range is inside the length scope though.
- ///
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public (int Offset, int Length) GetOffsetAndLength(int length)
- {
- int start;
- var startIndex = Start;
- if (startIndex.IsFromEnd)
- start = length - startIndex.Value;
- else
- start = startIndex.Value;
-
- int end;
- var endIndex = End;
- if (endIndex.IsFromEnd)
- end = length - endIndex.Value;
- else
- end = endIndex.Value;
-
- if ((uint)end > (uint)length || (uint)start > (uint)end)
- {
- throw new ArgumentOutOfRangeException(nameof(length));
- }
-
- return (start, end - start);
- }
- }
-}
diff --git a/test/TensorFlowNET.Keras.UnitTest/RuntimeHelpers.cs b/test/TensorFlowNET.Keras.UnitTest/RuntimeHelpers.cs
deleted file mode 100644
index 22f158c4..00000000
--- a/test/TensorFlowNET.Keras.UnitTest/RuntimeHelpers.cs
+++ /dev/null
@@ -1,39 +0,0 @@
-namespace System.Runtime.CompilerServices
-{
- internal static class RuntimeHelpers
- {
- ///
- /// Slices the specified array using the specified range.
- ///
- public static T[] GetSubArray(T[] array, Range range)
- {
- if (array == null)
- {
- throw new ArgumentNullException(nameof(array));
- }
-
- (int offset, int length) = range.GetOffsetAndLength(array.Length);
-
- if (default(T) != null || typeof(T[]) == array.GetType())
- {
- // We know the type of the array to be exactly T[].
-
- if (length == 0)
- {
- return Array.Empty();
- }
-
- var dest = new T[length];
- Array.Copy(array, offset, dest, 0, length);
- return dest;
- }
- else
- {
- // The array is actually a U[] where U:T.
- var dest = (T[])Array.CreateInstance(array.GetType().GetElementType(), length);
- Array.Copy(array, offset, dest, 0, length);
- return dest;
- }
- }
- }
-}
\ No newline at end of file