Browse Source

tests are passing

pull/1184/head
Alexander Novikov 2 years ago
parent
commit
2a377e2f91
3 changed files with 46 additions and 35 deletions
  1. +0
    -8
      src/TensorFlowNET.Core/Variables/variables.py.cs
  2. +26
    -14
      test/TensorFlowNET.UnitTest/PythonTest.cs
  3. +20
    -13
      test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs

+ 0
- 8
src/TensorFlowNET.Core/Variables/variables.py.cs View File

@@ -154,13 +154,5 @@ namespace Tensorflow


return op; return op;
} }

public static Tensor global_variables_initializer()
{
// if context.executing_eagerly():
// return control_flow_ops.no_op(name = "global_variables_initializer")
var group = variables_initializer(global_variables().ToArray());
return group;
}
} }
} }

+ 26
- 14
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -6,6 +6,7 @@ using System.Collections;
using System.Linq; using System.Linq;
using Tensorflow; using Tensorflow;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using System.Collections.Generic;


namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
{ {
@@ -144,11 +145,12 @@ namespace TensorFlowNET.UnitTest
Assert.IsTrue(np.allclose(array1, array2, rtol: eps)); Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
} }


private class CollectionComparer : System.Collections.IComparer
private class CollectionComparer : IComparer
{ {
private readonly double _epsilon; private readonly double _epsilon;


public CollectionComparer(double eps = 1e-06) {
public CollectionComparer(double eps = 1e-06)
{
_epsilon = eps; _epsilon = eps;
} }
public int Compare(object x, object y) public int Compare(object x, object y)
@@ -166,13 +168,15 @@ namespace TensorFlowNET.UnitTest
} }


public void assertAllCloseAccordingToType<T>( public void assertAllCloseAccordingToType<T>(
T[] expected,
T[] given,
ICollection expected,
ICollection<T> given,
double eps = 1e-6, double eps = 1e-6,
float float_eps = 1e-6f) float float_eps = 1e-6f)
{ {
// TODO: check if any of arguments is not double and change toletance // TODO: check if any of arguments is not double and change toletance
CollectionAssert.AreEqual(expected, given, new CollectionComparer(eps));
// remove givenAsDouble and cast expected instead
var givenAsDouble = given.Select(x => Convert.ToDouble(x)).ToArray();
CollectionAssert.AreEqual(expected, givenAsDouble, new CollectionComparer(eps));
} }


public void assertProtoEquals(object toProto, object o) public void assertProtoEquals(object toProto, object o)
@@ -241,17 +245,25 @@ namespace TensorFlowNET.UnitTest
// return self._eval_helper(tensors) // return self._eval_helper(tensors)
// else: // else:
{ {
var sess = tf.Session();
var sess = tf.get_default_session();
var ndarray = tensor.eval(sess); var ndarray = tensor.eval(sess);
if (typeof(T) == typeof(double))
if (typeof(T) == typeof(double)
|| typeof(T) == typeof(float)
|| typeof(T) == typeof(int))
{
result = Convert.ChangeType(ndarray, typeof(T));
}
else if (typeof(T) == typeof(double[]))
{
result = ndarray.ToMultiDimArray<double>();
}
else if (typeof(T) == typeof(float[]))
{ {
double x = ndarray;
result = x;
result = ndarray.ToMultiDimArray<float>();
} }
else if (typeof(T) == typeof(int))
else if (typeof(T) == typeof(int[]))
{ {
int x = ndarray;
result = x;
result = ndarray.ToMultiDimArray<int>();
} }
else else
{ {
@@ -457,12 +469,12 @@ namespace TensorFlowNET.UnitTest
else else
{ {


if (crash_if_inconsistent_args && !self._cached_graph.Equals(graph))
if (crash_if_inconsistent_args && self._cached_graph != null && !self._cached_graph.Equals(graph))
throw new ValueError(@"The graph used to get the cached session is throw new ValueError(@"The graph used to get the cached session is
different than the one that was used to create the different than the one that was used to create the
session. Maybe create a new session with session. Maybe create a new session with
self.session()"); self.session()");
if (crash_if_inconsistent_args && !self._cached_config.Equals(config))
if (crash_if_inconsistent_args && self._cached_config != null && !self._cached_config.Equals(config))
{ {
throw new ValueError(@"The config used to get the cached session is throw new ValueError(@"The config used to get the cached session is
different than the one that was used to create the different than the one that was used to create the


+ 20
- 13
test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs View File

@@ -1,8 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using System; using System;
using System.Linq; using System.Linq;
using System.Runtime.Intrinsics.X86;
using System.Security.AccessControl;
using Tensorflow.NumPy; using Tensorflow.NumPy;
using TensorFlowNET.UnitTest; using TensorFlowNET.UnitTest;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -12,18 +10,23 @@ namespace Tensorflow.Keras.UnitTest.Optimizers
[TestClass] [TestClass]
public class GradientDescentOptimizerTest : PythonTest public class GradientDescentOptimizerTest : PythonTest
{ {
private void TestBasicGeneric<T>() where T : struct
private static TF_DataType GetTypeForNumericType<T>() where T : struct
{ {
var dtype = Type.GetTypeCode(typeof(T)) switch
return Type.GetTypeCode(typeof(T)) switch
{ {
TypeCode.Single => np.float32, TypeCode.Single => np.float32,
TypeCode.Double => np.float64, TypeCode.Double => np.float64,
_ => throw new NotImplementedException(), _ => throw new NotImplementedException(),
}; };
}

private void TestBasicGeneric<T>() where T : struct
{
var dtype = GetTypeForNumericType<T>();


// train.GradientDescentOptimizer is V1 only API. // train.GradientDescentOptimizer is V1 only API.
tf.Graph().as_default(); tf.Graph().as_default();
using (self.cached_session())
using (var sess = self.cached_session())
{ {
var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype); var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype); var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
@@ -36,21 +39,25 @@ namespace Tensorflow.Keras.UnitTest.Optimizers
}; };
var sgd_op = optimizer.apply_gradients(grads_and_vars); var sgd_op = optimizer.apply_gradients(grads_and_vars);


var global_variables = variables.global_variables_initializer();
self.evaluate<T>(global_variables);
var global_variables = tf.global_variables_initializer();
sess.run(global_variables);

// Fetch params to validate initial values // Fetch params to validate initial values
var initialVar0 = sess.run(var0);
var valu = var0.eval(sess);
var initialVar1 = sess.run(var1);
// TODO: use self.evaluate<T[]> instead of self.evaluate<double[]> // TODO: use self.evaluate<T[]> instead of self.evaluate<double[]>
self.assertAllCloseAccordingToType(new double[] { 1.0, 2.0 }, self.evaluate<double[]>(var0));
self.assertAllCloseAccordingToType(new double[] { 3.0, 4.0 }, self.evaluate<double[]>(var1));
self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));
// Run 1 step of sgd // Run 1 step of sgd
sgd_op.run(); sgd_op.run();
// Validate updated params // Validate updated params
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
new double[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
self.evaluate<double[]>(var0));
new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
self.evaluate<T[]>(var0));
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
new double[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
self.evaluate<double[]>(var1));
new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
self.evaluate<T[]>(var1));
// TODO: self.assertEqual(0, len(optimizer.variables())); // TODO: self.assertEqual(0, len(optimizer.variables()));
} }
} }


Loading…
Cancel
Save