diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index b888f883..9396da58 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -138,7 +138,7 @@ namespace TensorFlowNET.UnitTest /// public T evaluate(Tensor tensor) { - var results = new Dictionary(); + object result = null; // if context.executing_eagerly(): // return self._eval_helper(tensors) // else: @@ -146,26 +146,25 @@ namespace TensorFlowNET.UnitTest var sess = ops.get_default_session(); if (sess == null) sess = self.session(); - T t_result = (T)(object)null; with(sess, s => { - var ndarray=tensor.eval(); + var ndarray=tensor.eval(); if (typeof(T) == typeof(double)) { - double d = ndarray; - t_result = (T)(object)d; + double x = ndarray; + result=x; } else if (typeof(T) == typeof(int)) { - int d = ndarray; - t_result = (T) (object) d; + int x = ndarray; + result = x; } else { - t_result = (T)(object)ndarray; + result = ndarray; } }); - return t_result; + return (T)result; } } diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs index 58b5a086..9b259e57 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs @@ -12,7 +12,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test public class CondTestCases : PythonTest { [TestMethod] - public void testCondTrue() + public void testCondTrue_ConstOnly() { var graph = tf.Graph().as_default(); @@ -31,7 +31,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test } [TestMethod] - public void testCondFalse() + public void testCondFalse_ConstOnly() { var graph = tf.Graph().as_default(); @@ -49,6 +49,40 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test }); } + [TestMethod] + public void testCondTrue() + { + var graph = tf.Graph().as_default(); + + with(tf.Session(graph), sess => + { + var x = tf.constant(2); + var y = tf.constant(5); + var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), + () => tf.add(y, tf.constant(23))); + //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); + int result = z.eval(sess); + assertEquals(result, 34); + }); + } + + //[Ignore("This Test Fails due to missing edges in the graph!")] + [TestMethod] + public void testCondFalse() + { + var graph = tf.Graph().as_default(); + + with(tf.Session(graph), sess => + { + var x = tf.constant(2); + var y = tf.constant(1); + var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), + () => tf.add(y, tf.constant(23))); + int result = z.eval(sess); + assertEquals(result, 24); + }); + } + // NOTE: all other test python test cases of this class are either not needed due to strong typing or dest a deprecated api }