| @@ -154,13 +154,5 @@ namespace Tensorflow | |||
| return op; | |||
| } | |||
| public static Tensor global_variables_initializer() | |||
| { | |||
| // if context.executing_eagerly(): | |||
| // return control_flow_ops.no_op(name = "global_variables_initializer") | |||
| var group = variables_initializer(global_variables().ToArray()); | |||
| return group; | |||
| } | |||
| } | |||
| } | |||
| @@ -6,6 +6,7 @@ using System.Collections; | |||
| using System.Linq; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| using System.Collections.Generic; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| @@ -144,11 +145,12 @@ namespace TensorFlowNET.UnitTest | |||
| Assert.IsTrue(np.allclose(array1, array2, rtol: eps)); | |||
| } | |||
| private class CollectionComparer : System.Collections.IComparer | |||
| private class CollectionComparer : IComparer | |||
| { | |||
| private readonly double _epsilon; | |||
| public CollectionComparer(double eps = 1e-06) { | |||
| public CollectionComparer(double eps = 1e-06) | |||
| { | |||
| _epsilon = eps; | |||
| } | |||
| public int Compare(object x, object y) | |||
| @@ -166,13 +168,15 @@ namespace TensorFlowNET.UnitTest | |||
| } | |||
| public void assertAllCloseAccordingToType<T>( | |||
| T[] expected, | |||
| T[] given, | |||
| ICollection expected, | |||
| ICollection<T> given, | |||
| double eps = 1e-6, | |||
| float float_eps = 1e-6f) | |||
| { | |||
| // TODO: check if any of arguments is not double and change toletance | |||
| CollectionAssert.AreEqual(expected, given, new CollectionComparer(eps)); | |||
| // remove givenAsDouble and cast expected instead | |||
| var givenAsDouble = given.Select(x => Convert.ToDouble(x)).ToArray(); | |||
| CollectionAssert.AreEqual(expected, givenAsDouble, new CollectionComparer(eps)); | |||
| } | |||
| public void assertProtoEquals(object toProto, object o) | |||
| @@ -241,17 +245,25 @@ namespace TensorFlowNET.UnitTest | |||
| // return self._eval_helper(tensors) | |||
| // else: | |||
| { | |||
| var sess = tf.Session(); | |||
| var sess = tf.get_default_session(); | |||
| var ndarray = tensor.eval(sess); | |||
| if (typeof(T) == typeof(double)) | |||
| if (typeof(T) == typeof(double) | |||
| || typeof(T) == typeof(float) | |||
| || typeof(T) == typeof(int)) | |||
| { | |||
| result = Convert.ChangeType(ndarray, typeof(T)); | |||
| } | |||
| else if (typeof(T) == typeof(double[])) | |||
| { | |||
| result = ndarray.ToMultiDimArray<double>(); | |||
| } | |||
| else if (typeof(T) == typeof(float[])) | |||
| { | |||
| double x = ndarray; | |||
| result = x; | |||
| result = ndarray.ToMultiDimArray<float>(); | |||
| } | |||
| else if (typeof(T) == typeof(int)) | |||
| else if (typeof(T) == typeof(int[])) | |||
| { | |||
| int x = ndarray; | |||
| result = x; | |||
| result = ndarray.ToMultiDimArray<int>(); | |||
| } | |||
| else | |||
| { | |||
| @@ -457,12 +469,12 @@ namespace TensorFlowNET.UnitTest | |||
| else | |||
| { | |||
| if (crash_if_inconsistent_args && !self._cached_graph.Equals(graph)) | |||
| if (crash_if_inconsistent_args && self._cached_graph != null && !self._cached_graph.Equals(graph)) | |||
| throw new ValueError(@"The graph used to get the cached session is | |||
| different than the one that was used to create the | |||
| session. Maybe create a new session with | |||
| self.session()"); | |||
| if (crash_if_inconsistent_args && !self._cached_config.Equals(config)) | |||
| if (crash_if_inconsistent_args && self._cached_config != null && !self._cached_config.Equals(config)) | |||
| { | |||
| throw new ValueError(@"The config used to get the cached session is | |||
| different than the one that was used to create the | |||
| @@ -1,8 +1,6 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Linq; | |||
| using System.Runtime.Intrinsics.X86; | |||
| using System.Security.AccessControl; | |||
| using Tensorflow.NumPy; | |||
| using TensorFlowNET.UnitTest; | |||
| using static Tensorflow.Binding; | |||
| @@ -12,18 +10,23 @@ namespace Tensorflow.Keras.UnitTest.Optimizers | |||
| [TestClass] | |||
| public class GradientDescentOptimizerTest : PythonTest | |||
| { | |||
| private void TestBasicGeneric<T>() where T : struct | |||
| private static TF_DataType GetTypeForNumericType<T>() where T : struct | |||
| { | |||
| var dtype = Type.GetTypeCode(typeof(T)) switch | |||
| return Type.GetTypeCode(typeof(T)) switch | |||
| { | |||
| TypeCode.Single => np.float32, | |||
| TypeCode.Double => np.float64, | |||
| _ => throw new NotImplementedException(), | |||
| }; | |||
| } | |||
| private void TestBasicGeneric<T>() where T : struct | |||
| { | |||
| var dtype = GetTypeForNumericType<T>(); | |||
| // train.GradientDescentOptimizer is V1 only API. | |||
| tf.Graph().as_default(); | |||
| using (self.cached_session()) | |||
| using (var sess = self.cached_session()) | |||
| { | |||
| var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype); | |||
| var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype); | |||
| @@ -36,21 +39,25 @@ namespace Tensorflow.Keras.UnitTest.Optimizers | |||
| }; | |||
| var sgd_op = optimizer.apply_gradients(grads_and_vars); | |||
| var global_variables = variables.global_variables_initializer(); | |||
| self.evaluate<T>(global_variables); | |||
| var global_variables = tf.global_variables_initializer(); | |||
| sess.run(global_variables); | |||
| // Fetch params to validate initial values | |||
| var initialVar0 = sess.run(var0); | |||
| var valu = var0.eval(sess); | |||
| var initialVar1 = sess.run(var1); | |||
| // TODO: use self.evaluate<T[]> instead of self.evaluate<double[]> | |||
| self.assertAllCloseAccordingToType(new double[] { 1.0, 2.0 }, self.evaluate<double[]>(var0)); | |||
| self.assertAllCloseAccordingToType(new double[] { 3.0, 4.0 }, self.evaluate<double[]>(var1)); | |||
| self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0)); | |||
| self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1)); | |||
| // Run 1 step of sgd | |||
| sgd_op.run(); | |||
| // Validate updated params | |||
| self.assertAllCloseAccordingToType( | |||
| new double[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 }, | |||
| self.evaluate<double[]>(var0)); | |||
| new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 }, | |||
| self.evaluate<T[]>(var0)); | |||
| self.assertAllCloseAccordingToType( | |||
| new double[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 }, | |||
| self.evaluate<double[]>(var1)); | |||
| new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 }, | |||
| self.evaluate<T[]>(var1)); | |||
| // TODO: self.assertEqual(0, len(optimizer.variables())); | |||
| } | |||
| } | |||