|
|
|
@@ -1,4 +1,5 @@ |
|
|
|
using Microsoft.VisualStudio.TestTools.UnitTesting; |
|
|
|
using Microsoft.VisualStudio.TestPlatform.Utilities; |
|
|
|
using Microsoft.VisualStudio.TestTools.UnitTesting; |
|
|
|
using System; |
|
|
|
using System.Linq; |
|
|
|
using Tensorflow.NumPy; |
|
|
|
@@ -20,7 +21,7 @@ namespace Tensorflow.Keras.UnitTest.Optimizers |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
private void TestBasicGeneric<T>() where T : struct |
|
|
|
private void TestBasic<T>() where T : struct |
|
|
|
{ |
|
|
|
var dtype = GetTypeForNumericType<T>(); |
|
|
|
|
|
|
|
@@ -42,11 +43,9 @@ namespace Tensorflow.Keras.UnitTest.Optimizers |
|
|
|
var global_variables = tf.global_variables_initializer(); |
|
|
|
sess.run(global_variables); |
|
|
|
|
|
|
|
// Fetch params to validate initial values |
|
|
|
var initialVar0 = sess.run(var0); |
|
|
|
var valu = var0.eval(sess); |
|
|
|
var initialVar1 = sess.run(var1); |
|
|
|
// TODO: use self.evaluate<T[]> instead of self.evaluate<double[]> |
|
|
|
// Fetch params to validate initial values |
|
|
|
self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0)); |
|
|
|
self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1)); |
|
|
|
// Run 1 step of sgd |
|
|
|
@@ -66,10 +65,9 @@ namespace Tensorflow.Keras.UnitTest.Optimizers |
|
|
|
public void TestBasic() |
|
|
|
{ |
|
|
|
//TODO: add np.half |
|
|
|
TestBasicGeneric<float>(); |
|
|
|
TestBasicGeneric<double>(); |
|
|
|
TestBasic<float>(); |
|
|
|
TestBasic<double>(); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
} |