| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Microsoft.VisualStudio.TestPlatform.Utilities; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| @@ -20,7 +21,7 @@ namespace Tensorflow.Keras.UnitTest.Optimizers | |||||
| }; | }; | ||||
| } | } | ||||
| private void TestBasicGeneric<T>() where T : struct | |||||
| private void TestBasic<T>() where T : struct | |||||
| { | { | ||||
| var dtype = GetTypeForNumericType<T>(); | var dtype = GetTypeForNumericType<T>(); | ||||
| @@ -42,11 +43,9 @@ namespace Tensorflow.Keras.UnitTest.Optimizers | |||||
| var global_variables = tf.global_variables_initializer(); | var global_variables = tf.global_variables_initializer(); | ||||
| sess.run(global_variables); | sess.run(global_variables); | ||||
| // Fetch params to validate initial values | |||||
| var initialVar0 = sess.run(var0); | var initialVar0 = sess.run(var0); | ||||
| var valu = var0.eval(sess); | |||||
| var initialVar1 = sess.run(var1); | var initialVar1 = sess.run(var1); | ||||
| // TODO: use self.evaluate<T[]> instead of self.evaluate<double[]> | |||||
| // Fetch params to validate initial values | |||||
| self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0)); | self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0)); | ||||
| self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1)); | self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1)); | ||||
| // Run 1 step of sgd | // Run 1 step of sgd | ||||
| @@ -66,10 +65,9 @@ namespace Tensorflow.Keras.UnitTest.Optimizers | |||||
| public void TestBasic() | public void TestBasic() | ||||
| { | { | ||||
| //TODO: add np.half | //TODO: add np.half | ||||
| TestBasicGeneric<float>(); | |||||
| TestBasicGeneric<double>(); | |||||
| TestBasic<float>(); | |||||
| TestBasic<double>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||