@@ -154,13 +154,5 @@ namespace Tensorflow
             return op;
         }

-        public static Tensor global_variables_initializer()
-        {
-            // if context.executing_eagerly():
-            //     return control_flow_ops.no_op(name = "global_variables_initializer")
-            var group = variables_initializer(global_variables().ToArray());
-            return group;
-        }
-
     }
 }
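
A minimal usage sketch, assuming the binding-level tf.global_variables_initializer() (the call the updated test further below switches to) is the intended replacement for the removed static helper; the variable here is illustrative:

    // Graph-mode (V1) style initialization, mirroring the test changes below.
    tf.Graph().as_default();
    var v = tf.Variable(new[] { 1.0, 2.0 });
    var init = tf.global_variables_initializer();
    using (var sess = tf.Session())
    {
        sess.run(init);          // runs the grouped initializer op for all global variables
        var value = sess.run(v); // the variable now holds its initial value
    }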
@@ -6,6 +6,7 @@ using System.Collections;
 using System.Linq;
 using Tensorflow;
 using static Tensorflow.Binding;
+using System.Collections.Generic;

 namespace TensorFlowNET.UnitTest
 {
@@ -144,11 +145,12 @@ namespace TensorFlowNET.UnitTest
             Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
         }

-        private class CollectionComparer : System.Collections.IComparer
+        private class CollectionComparer : IComparer
         {
             private readonly double _epsilon;

-            public CollectionComparer(double eps = 1e-06) {
+            public CollectionComparer(double eps = 1e-06)
+            {
                 _epsilon = eps;
             }
             public int Compare(object x, object y)
@@ -166,13 +168,15 @@ namespace TensorFlowNET.UnitTest
         }

         public void assertAllCloseAccordingToType<T>(
-            T[] expected,
-            T[] given,
+            ICollection expected,
+            ICollection<T> given,
             double eps = 1e-6,
             float float_eps = 1e-6f)
         {
             // TODO: check if any of arguments is not double and change toletance
-            CollectionAssert.AreEqual(expected, given, new CollectionComparer(eps));
+            // remove givenAsDouble and cast expected instead
+            var givenAsDouble = given.Select(x => Convert.ToDouble(x)).ToArray();
+            CollectionAssert.AreEqual(expected, givenAsDouble, new CollectionComparer(eps));
         }

         public void assertProtoEquals(object toProto, object o)
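
An illustrative call for the widened signature, assuming the expected values are passed as a plain array (arrays satisfy ICollection) and the given values come from evaluate<T[]>; note that float_eps is not referenced in the body shown above:

    // given is converted element-wise to double before CollectionAssert
    // applies the CollectionComparer with the eps tolerance.
    float[] given = { 1.0f, 2.0f };
    self.assertAllCloseAccordingToType(new double[] { 1.0, 2.0 }, given, eps: 1e-6);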
@@ -241,17 +245,25 @@ namespace TensorFlowNET.UnitTest
             //    return self._eval_helper(tensors)
             // else:
             {
-                var sess = tf.Session();
+                var sess = tf.get_default_session();
                 var ndarray = tensor.eval(sess);
-                if (typeof(T) == typeof(double))
+                if (typeof(T) == typeof(double)
+                    || typeof(T) == typeof(float)
+                    || typeof(T) == typeof(int))
+                {
+                    result = Convert.ChangeType(ndarray, typeof(T));
+                }
+                else if (typeof(T) == typeof(double[]))
+                {
+                    result = ndarray.ToMultiDimArray<double>();
+                }
+                else if (typeof(T) == typeof(float[]))
                 {
-                    double x = ndarray;
-                    result = x;
+                    result = ndarray.ToMultiDimArray<float>();
                 }
-                else if (typeof(T) == typeof(int))
+                else if (typeof(T) == typeof(int[]))
                 {
-                    int x = ndarray;
-                    result = x;
+                    result = ndarray.ToMultiDimArray<int>();
                 }
                 else
                 {
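
A sketch of the call patterns the reworked evaluate<T> is meant to support, assuming a graph-mode session is already registered as the default (the tensor names here are illustrative):

    // Scalar T goes through Convert.ChangeType on the evaluated NDArray;
    // array T goes through NDArray.ToMultiDimArray<>().
    var scalar = self.evaluate<double>(loss_tensor);
    var values = self.evaluate<double[]>(weights_tensor);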
@@ -457,12 +469,12 @@ namespace TensorFlowNET.UnitTest
             else
             {
-                if (crash_if_inconsistent_args && !self._cached_graph.Equals(graph))
+                if (crash_if_inconsistent_args && self._cached_graph != null && !self._cached_graph.Equals(graph))
                     throw new ValueError(@"The graph used to get the cached session is
                                 different than the one that was used to create the
                                 session. Maybe create a new session with
                                 self.session()");
-                if (crash_if_inconsistent_args && !self._cached_config.Equals(config))
+                if (crash_if_inconsistent_args && self._cached_config != null && !self._cached_config.Equals(config))
                 {
                     throw new ValueError(@"The config used to get the cached session is
                                 different than the one that was used to create the
@@ -1,8 +1,6 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using System;
 using System.Linq;
-using System.Runtime.Intrinsics.X86;
-using System.Security.AccessControl;
 using Tensorflow.NumPy;
 using TensorFlowNET.UnitTest;
 using static Tensorflow.Binding;
@@ -12,18 +10,23 @@ namespace Tensorflow.Keras.UnitTest.Optimizers
     [TestClass]
     public class GradientDescentOptimizerTest : PythonTest
     {
-        private void TestBasicGeneric<T>() where T : struct
+        private static TF_DataType GetTypeForNumericType<T>() where T : struct
         {
-            var dtype = Type.GetTypeCode(typeof(T)) switch
+            return Type.GetTypeCode(typeof(T)) switch
             {
                 TypeCode.Single => np.float32,
                 TypeCode.Double => np.float64,
                 _ => throw new NotImplementedException(),
             };
+        }
+
+        private void TestBasicGeneric<T>() where T : struct
+        {
+            var dtype = GetTypeForNumericType<T>();

             // train.GradientDescentOptimizer is V1 only API.
             tf.Graph().as_default();
-            using (self.cached_session())
+            using (var sess = self.cached_session())
             {
                 var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
                 var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
@@ -36,21 +39,25 @@ namespace Tensorflow.Keras.UnitTest.Optimizers
                 };
                 var sgd_op = optimizer.apply_gradients(grads_and_vars);

-                var global_variables = variables.global_variables_initializer();
-                self.evaluate<T>(global_variables);
+                var global_variables = tf.global_variables_initializer();
+                sess.run(global_variables);
+
                 // Fetch params to validate initial values
+                var initialVar0 = sess.run(var0);
+                var valu = var0.eval(sess);
+                var initialVar1 = sess.run(var1);
                 // TODO: use self.evaluate<T[]> instead of self.evaluate<double[]>
-                self.assertAllCloseAccordingToType(new double[] { 1.0, 2.0 }, self.evaluate<double[]>(var0));
-                self.assertAllCloseAccordingToType(new double[] { 3.0, 4.0 }, self.evaluate<double[]>(var1));
+                self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
+                self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));
                 // Run 1 step of sgd
                 sgd_op.run();
                 // Validate updated params
                 self.assertAllCloseAccordingToType(
-                    new double[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
-                    self.evaluate<double[]>(var0));
+                    new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
+                    self.evaluate<T[]>(var0));
                 self.assertAllCloseAccordingToType(
-                    new double[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
-                    self.evaluate<double[]>(var1));
+                    new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
+                    self.evaluate<T[]>(var1));
                 // TODO: self.assertEqual(0, len(optimizer.variables()));
             }
         }
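
A hedged sketch of how the extracted helper and the generic test driver are expected to be wired up, assuming [TestMethod] entry points along these lines exist (or are added) elsewhere in the class; the method names here are hypothetical:

    // GetTypeForNumericType<T> maps only float and double today; any other
    // numeric type falls through to NotImplementedException.
    [TestMethod]
    public void TestBasicFloat() => TestBasicGeneric<float>();   // runs with np.float32

    [TestMethod]
    public void TestBasicDouble() => TestBasicGeneric<double>(); // runs with np.float64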