diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs
index e43dd913..a610f4e7 100644
--- a/src/TensorFlowNET.Core/Sessions/Session.cs
+++ b/src/TensorFlowNET.Core/Sessions/Session.cs
@@ -38,6 +38,12 @@ namespace Tensorflow
Status.Check(true);
}
+ public Session as_default()
+ {
+ tf.defaultSession = this;
+ return this;
+ }
+
public static Session LoadFromSavedModel(string path)
{
var graph = c_api.TF_NewGraph();
diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs
index 5d9bb374..508a0a81 100644
--- a/test/TensorFlowNET.UnitTest/PythonTest.cs
+++ b/test/TensorFlowNET.UnitTest/PythonTest.cs
@@ -132,23 +132,66 @@ namespace TensorFlowNET.UnitTest
}
///
- /// Evaluates tensors and returns numpy values.
+ /// Evaluates tensors and returns a dictionary of {name:result, ...}.
/// A Tensor or a nested list/tuple of Tensors.
///
- /// tensors numpy values.
- public object evaluate(params Tensor[] tensors)
+ public Dictionary evaluate(params Tensor[] tensors)
{
+ var results = new Dictionary();
// if context.executing_eagerly():
// return self._eval_helper(tensors)
// else:
{
var sess = ops.get_default_session();
- if (sess == None)
- with(self.session(), s => sess = s);
- return sess.run(tensors);
+ if (sess == null)
+ sess = self.session();
+
+ with(sess, s =>
+ {
+ foreach (var t in tensors)
+ results[t.name] = t.eval();
+ });
+ return results;
+ }
+ }
+
+ public NDArray evaluate(Tensor tensor)
+ {
+ NDArray result = null;
+ // if context.executing_eagerly():
+ // return self._eval_helper(tensors)
+ // else:
+ {
+ var sess = ops.get_default_session();
+ if (sess == null)
+ sess = self.session();
+ with(sess, s =>
+ {
+ result = tensor.eval();
+ });
+ return result;
+ }
+ }
+
+ public object eval_scalar(Tensor tensor)
+ {
+ NDArray result = null;
+ // if context.executing_eagerly():
+ // return self._eval_helper(tensors)
+ // else:
+ {
+ var sess = ops.get_default_session();
+ if (sess == null)
+ sess = self.session();
+ with(sess, s =>
+ {
+ result = tensor.eval();
+ });
+ return result.Array.GetValue(0);
}
}
+
//Returns a TensorFlow Session for use in executing tests.
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
{
@@ -188,16 +231,11 @@ namespace TensorFlowNET.UnitTest
//if (context.executing_eagerly())
// yield None
//else
- {
- with(self._create_session(graph, config, force_gpu), sess =>
- {
- with(self._constrain_devices_and_set_default(sess, use_gpu, force_gpu), (x) =>
- {
- s = sess;
- });
- });
- }
- return s;
+ //{
+ s = self._create_session(graph, config, force_gpu);
+ self._constrain_devices_and_set_default(s, use_gpu, force_gpu);
+ //}
+ return s.as_default();
}
private IPython _constrain_devices_and_set_default(Session sess, bool useGpu, bool forceGpu)
diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
index 85908baf..0b5f0879 100644
--- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
+++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
@@ -13,53 +13,29 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
[TestMethod]
public void testCondTrue()
{
- var x = tf.constant(2);
- var y = tf.constant(5);
- var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
- () => tf.add(y, tf.constant(23)));
- self.assertEquals(self.evaluate(z), 34);
+ with(tf.Graph().as_default(), g =>
+ {
+ var x = tf.constant(2);
+ var y = tf.constant(5);
+ var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
+ () => tf.add(y, tf.constant(23)));
+ //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
+ self.assertEquals(eval_scalar(z), 34);
+ });
}
- [Ignore("Todo")]
+ [Ignore("This Test Fails due to missing edges in the graph!")]
[TestMethod]
public void testCondFalse()
{
- // def testCondFalse(self):
- // x = constant_op.constant(2)
- // y = constant_op.constant(1)
- // z = control_flow_ops.cond(
- // math_ops.less(
- // x,
- // y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23))
- // self.assertEquals(self.evaluate(z), 24)
- }
-
- [Ignore("Todo")]
- [TestMethod]
- public void testCondTrueLegacy()
- {
- // def testCondTrueLegacy(self):
- // x = constant_op.constant(2)
- // y = constant_op.constant(5)
- // z = control_flow_ops.cond(
- // math_ops.less(x, y),
- // fn1=lambda: math_ops.multiply(x, 17),
- // fn2=lambda: math_ops.add(y, 23))
- // self.assertEquals(self.evaluate(z), 34)
- }
-
- [Ignore("Todo")]
- [TestMethod]
- public void testCondFalseLegacy()
- {
- // def testCondFalseLegacy(self):
- // x = constant_op.constant(2)
- // y = constant_op.constant(1)
- // z = control_flow_ops.cond(
- // math_ops.less(x, y),
- // fn1=lambda: math_ops.multiply(x, 17),
- // fn2=lambda: math_ops.add(y, 23))
- // self.assertEquals(self.evaluate(z), 24)
+ with(tf.Graph().as_default(), g =>
+ {
+ var x = tf.constant(2);
+ var y = tf.constant(1);
+ var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
+ () => tf.add(y, tf.constant(23)));
+ self.assertEquals(eval_scalar(z), 24);
+ });
}
[Ignore("Todo")]