Browse Source

small fixes

pull/1184/head
Alexander 2 years ago
parent
commit
f7b8dba00b
1 changed files with 6 additions and 8 deletions
  1. +6
    -8
      test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs

+ 6
- 8
test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq;
using Tensorflow.NumPy;
@@ -20,7 +21,7 @@ namespace Tensorflow.Keras.UnitTest.Optimizers
};
}

private void TestBasicGeneric<T>() where T : struct
private void TestBasic<T>() where T : struct
{
var dtype = GetTypeForNumericType<T>();

@@ -42,11 +43,9 @@ namespace Tensorflow.Keras.UnitTest.Optimizers
var global_variables = tf.global_variables_initializer();
sess.run(global_variables);

// Fetch params to validate initial values
var initialVar0 = sess.run(var0);
var valu = var0.eval(sess);
var initialVar1 = sess.run(var1);
// TODO: use self.evaluate<T[]> instead of self.evaluate<double[]>
// Fetch params to validate initial values
self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));
// Run 1 step of sgd
@@ -66,10 +65,9 @@ namespace Tensorflow.Keras.UnitTest.Optimizers
public void TestBasic()
{
//TODO: add np.half
TestBasicGeneric<float>();
TestBasicGeneric<double>();
TestBasic<float>();
TestBasic<double>();
}


}
}

Loading…
Cancel
Save