diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index e43dd913..a610f4e7 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -38,6 +38,12 @@ namespace Tensorflow Status.Check(true); } + public Session as_default() + { + tf.defaultSession = this; + return this; + } + public static Session LoadFromSavedModel(string path) { var graph = c_api.TF_NewGraph(); diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index 5d9bb374..508a0a81 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -132,23 +132,66 @@ namespace TensorFlowNET.UnitTest } /// - /// Evaluates tensors and returns numpy values. + /// Evaluates tensors and returns a dictionary of {name:result, ...}. /// A Tensor or a nested list/tuple of Tensors. /// - /// tensors numpy values. - public object evaluate(params Tensor[] tensors) + public Dictionary evaluate(params Tensor[] tensors) { + var results = new Dictionary(); // if context.executing_eagerly(): // return self._eval_helper(tensors) // else: { var sess = ops.get_default_session(); - if (sess == None) - with(self.session(), s => sess = s); - return sess.run(tensors); + if (sess == null) + sess = self.session(); + + with(sess, s => + { + foreach (var t in tensors) + results[t.name] = t.eval(); + }); + return results; + } + } + + public NDArray evaluate(Tensor tensor) + { + NDArray result = null; + // if context.executing_eagerly(): + // return self._eval_helper(tensors) + // else: + { + var sess = ops.get_default_session(); + if (sess == null) + sess = self.session(); + with(sess, s => + { + result = tensor.eval(); + }); + return result; + } + } + + public object eval_scalar(Tensor tensor) + { + NDArray result = null; + // if context.executing_eagerly(): + // return self._eval_helper(tensors) + // else: + { + var sess = ops.get_default_session(); + if (sess == null) + sess = self.session(); + with(sess, s => + { + result = tensor.eval(); + }); + return result.Array.GetValue(0); } } + //Returns a TensorFlow Session for use in executing tests. public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false) { @@ -188,16 +231,11 @@ namespace TensorFlowNET.UnitTest //if (context.executing_eagerly()) // yield None //else - { - with(self._create_session(graph, config, force_gpu), sess => - { - with(self._constrain_devices_and_set_default(sess, use_gpu, force_gpu), (x) => - { - s = sess; - }); - }); - } - return s; + //{ + s = self._create_session(graph, config, force_gpu); + self._constrain_devices_and_set_default(s, use_gpu, force_gpu); + //} + return s.as_default(); } private IPython _constrain_devices_and_set_default(Session sess, bool useGpu, bool forceGpu) diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs index 85908baf..0b5f0879 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs @@ -13,53 +13,29 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test [TestMethod] public void testCondTrue() { - var x = tf.constant(2); - var y = tf.constant(5); - var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), - () => tf.add(y, tf.constant(23))); - self.assertEquals(self.evaluate(z), 34); + with(tf.Graph().as_default(), g => + { + var x = tf.constant(2); + var y = tf.constant(5); + var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), + () => tf.add(y, tf.constant(23))); + //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); + self.assertEquals(eval_scalar(z), 34); + }); } - [Ignore("Todo")] + [Ignore("This Test Fails due to missing edges in the graph!")] [TestMethod] public void testCondFalse() { - // def testCondFalse(self): - // x = constant_op.constant(2) - // y = constant_op.constant(1) - // z = control_flow_ops.cond( - // math_ops.less( - // x, - // y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23)) - // self.assertEquals(self.evaluate(z), 24) - } - - [Ignore("Todo")] - [TestMethod] - public void testCondTrueLegacy() - { - // def testCondTrueLegacy(self): - // x = constant_op.constant(2) - // y = constant_op.constant(5) - // z = control_flow_ops.cond( - // math_ops.less(x, y), - // fn1=lambda: math_ops.multiply(x, 17), - // fn2=lambda: math_ops.add(y, 23)) - // self.assertEquals(self.evaluate(z), 34) - } - - [Ignore("Todo")] - [TestMethod] - public void testCondFalseLegacy() - { - // def testCondFalseLegacy(self): - // x = constant_op.constant(2) - // y = constant_op.constant(1) - // z = control_flow_ops.cond( - // math_ops.less(x, y), - // fn1=lambda: math_ops.multiply(x, 17), - // fn2=lambda: math_ops.add(y, 23)) - // self.assertEquals(self.evaluate(z), 24) + with(tf.Graph().as_default(), g => + { + var x = tf.constant(2); + var y = tf.constant(1); + var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), + () => tf.add(y, tf.constant(23))); + self.assertEquals(eval_scalar(z), 24); + }); } [Ignore("Todo")]