From 2a377e2f91b40083f5de86f01b57b32bad5a5932 Mon Sep 17 00:00:00 2001
From: Alexander Novikov
Date: Tue, 7 Nov 2023 19:23:34 +0000
Subject: [PATCH] tests are passing

---
 .../Variables/variables.py.cs                 |  8 ----
 test/TensorFlowNET.UnitTest/PythonTest.cs     | 40 ++++++++++++-------
 .../Training/GradientDescentOptimizerTests.cs | 33 +++++++++------
 3 files changed, 46 insertions(+), 35 deletions(-)

diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs
index f3ae248e..91f57e29 100644
--- a/src/TensorFlowNET.Core/Variables/variables.py.cs
+++ b/src/TensorFlowNET.Core/Variables/variables.py.cs
@@ -154,13 +154,5 @@ namespace Tensorflow
 
             return op;
         }
-
-        public static Tensor global_variables_initializer()
-        {
-            // if context.executing_eagerly():
-            //     return control_flow_ops.no_op(name = "global_variables_initializer")
-            var group = variables_initializer(global_variables().ToArray());
-            return group;
-        }
     }
 }
diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs
index 12fd7236..090ef097 100644
--- a/test/TensorFlowNET.UnitTest/PythonTest.cs
+++ b/test/TensorFlowNET.UnitTest/PythonTest.cs
@@ -6,6 +6,7 @@ using System.Collections;
 using System.Linq;
 using Tensorflow;
 using static Tensorflow.Binding;
+using System.Collections.Generic;
 
 namespace TensorFlowNET.UnitTest
 {
@@ -144,11 +145,12 @@ namespace TensorFlowNET.UnitTest
             Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
         }
 
-        private class CollectionComparer : System.Collections.IComparer
+        private class CollectionComparer : IComparer
         {
             private readonly double _epsilon;
 
-            public CollectionComparer(double eps = 1e-06) {
+            public CollectionComparer(double eps = 1e-06)
+            {
                 _epsilon = eps;
             }
             public int Compare(object x, object y)
@@ -166,13 +168,15 @@ namespace TensorFlowNET.UnitTest
         }
 
         public void assertAllCloseAccordingToType<T>(
-            T[] expected,
-            T[] given,
+            ICollection expected,
+            ICollection<T> given,
             double eps = 1e-6,
             float float_eps = 1e-6f)
         {
             // TODO: check if any of arguments is not double and change toletance
-            CollectionAssert.AreEqual(expected, given, new CollectionComparer(eps));
+            // remove givenAsDouble and cast expected instead
+            var givenAsDouble = given.Select(x => Convert.ToDouble(x)).ToArray();
+            CollectionAssert.AreEqual(expected, givenAsDouble, new CollectionComparer(eps));
         }
 
         public void assertProtoEquals(object toProto, object o)
@@ -241,17 +245,25 @@ namespace TensorFlowNET.UnitTest
             // return self._eval_helper(tensors)
             // else:
             {
-                var sess = tf.Session();
+                var sess = tf.get_default_session();
                 var ndarray = tensor.eval(sess);
-                if (typeof(T) == typeof(double))
+                if (typeof(T) == typeof(double)
+                    || typeof(T) == typeof(float)
+                    || typeof(T) == typeof(int))
                 {
-                    double x = ndarray;
-                    result = x;
+                    result = Convert.ChangeType(ndarray, typeof(T));
                 }
-                else if (typeof(T) == typeof(int))
+                else if (typeof(T) == typeof(double[]))
                 {
-                    int x = ndarray;
-                    result = x;
+                    result = ndarray.ToMultiDimArray<double>();
                 }
+                else if (typeof(T) == typeof(float[]))
+                {
+                    result = ndarray.ToMultiDimArray<float>();
+                }
+                else if (typeof(T) == typeof(int[]))
+                {
+                    result = ndarray.ToMultiDimArray<int>();
+                }
                 else
                 {
@@ -457,12 +469,12 @@ namespace TensorFlowNET.UnitTest
             }
             else
             {
-                if (crash_if_inconsistent_args && !self._cached_graph.Equals(graph))
+                if (crash_if_inconsistent_args && self._cached_graph != null && !self._cached_graph.Equals(graph))
                     throw new ValueError(@"The graph used to get the cached session is
                                 different than the one that was used to create the session.
                                Maybe create a new session with self.session()");
 
-                if (crash_if_inconsistent_args && !self._cached_config.Equals(config))
+                if (crash_if_inconsistent_args && self._cached_config != null && !self._cached_config.Equals(config))
                 {
                     throw new ValueError(@"The config used to get the cached session is
                                 different than the one that was used to create the
diff --git a/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs b/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs
index 977544ae..3059068f 100644
--- a/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs
+++ b/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs
@@ -1,8 +1,6 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using System;
 using System.Linq;
-using System.Runtime.Intrinsics.X86;
-using System.Security.AccessControl;
 using Tensorflow.NumPy;
 using TensorFlowNET.UnitTest;
 using static Tensorflow.Binding;
@@ -12,18 +10,23 @@ namespace Tensorflow.Keras.UnitTest.Optimizers
     [TestClass]
     public class GradientDescentOptimizerTest : PythonTest
     {
-        private void TestBasicGeneric<T>() where T : struct
+        private static TF_DataType GetTypeForNumericType<T>() where T : struct
         {
-            var dtype = Type.GetTypeCode(typeof(T)) switch
+            return Type.GetTypeCode(typeof(T)) switch
             {
                 TypeCode.Single => np.float32,
                 TypeCode.Double => np.float64,
                 _ => throw new NotImplementedException(),
             };
+        }
+
+        private void TestBasicGeneric<T>() where T : struct
+        {
+            var dtype = GetTypeForNumericType<T>();
 
             // train.GradientDescentOptimizer is V1 only API.
             tf.Graph().as_default();
-            using (self.cached_session())
+            using (var sess = self.cached_session())
             {
                 var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
                 var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
@@ -36,21 +39,25 @@ namespace Tensorflow.Keras.UnitTest.Optimizers
                 };
                 var sgd_op = optimizer.apply_gradients(grads_and_vars);
 
-                var global_variables = variables.global_variables_initializer();
-                self.evaluate(global_variables);
+                var global_variables = tf.global_variables_initializer();
+                sess.run(global_variables);
+                // Fetch params to validate initial values
+                var initialVar0 = sess.run(var0);
+                var valu = var0.eval(sess);
+                var initialVar1 = sess.run(var1);
 
                 // TODO: use self.evaluate instead of self.evaluate
-                self.assertAllCloseAccordingToType(new double[] { 1.0, 2.0 }, self.evaluate<double[]>(var0));
-                self.assertAllCloseAccordingToType(new double[] { 3.0, 4.0 }, self.evaluate<double[]>(var1));
+                self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
+                self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));
                 // Run 1 step of sgd
                 sgd_op.run();
                 // Validate updated params
                 self.assertAllCloseAccordingToType(
-                    new double[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
-                    self.evaluate<double[]>(var0));
+                    new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
+                    self.evaluate<T[]>(var0));
                 self.assertAllCloseAccordingToType(
-                    new double[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
-                    self.evaluate<double[]>(var1));
+                    new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
+                    self.evaluate<T[]>(var1));
                 // TODO: self.assertEqual(0, len(optimizer.variables()));
             }
         }
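
Note (not part of the patch): a minimal sketch of how the generic helper above is expected to be driven from concrete MSTest entry points. The [TestMethod] wrapper and its name below are assumptions for illustration only; the dtype mapping and the comparison behaviour it relies on come from GetTypeForNumericType<T>, evaluate<T>, and assertAllCloseAccordingToType<T> as changed in this diff.

    [TestMethod]
    public void TestBasic()
    {
        // float32 path: GetTypeForNumericType<float>() maps to np.float32, so
        // evaluate<T[]> yields float[] and assertAllCloseAccordingToType converts
        // the values to double[] (givenAsDouble) before comparing them against
        // the double[] expectations.
        TestBasicGeneric<float>();

        // float64 path: GetTypeForNumericType<double>() maps to np.float64 and
        // the evaluated values are compared as double[] directly.
        TestBasicGeneric<double>();
    }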