diff --git a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
index 0de8bdeb..98ccbb06 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
@@ -27,10 +27,17 @@ namespace Tensorflow
return op.type == "Switch" || op.type == "RefSwitch";
}
+ ///
+ /// Return the control flow context for the output of an op.
+ ///
public static IControlFlowContext GetOutputContext(Operation op)
{
var ctxt = op._get_control_flow_context();
-
+ // Exit nodes usually have a control flow context, except in the case where the
+ // exit node was imported via import_graph_def (in which case no nodes have
+ // control flow contexts).
+ if (ctxt != null && IsLoopExit(op))
+ ctxt = ctxt.outer_context;
return ctxt;
}
}
diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs
index c997c9bf..5af8bd52 100644
--- a/test/TensorFlowNET.UnitTest/PythonTest.cs
+++ b/test/TensorFlowNET.UnitTest/PythonTest.cs
@@ -5,6 +5,7 @@ using System.Linq;
using System.Text;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
+using Tensorflow.Util;
namespace TensorFlowNET.UnitTest
{
@@ -13,6 +14,15 @@ namespace TensorFlowNET.UnitTest
///
public class PythonTest : Python
{
+ #region python compatibility layer
+ protected PythonTest self { get => this; }
+ protected object None {
+ get { return null; }
+ }
+ #endregion
+
+ #region pytest assertions
+
public void assertItemsEqual(ICollection given, ICollection expected)
{
Assert.IsNotNull(expected);
@@ -20,20 +30,62 @@ namespace TensorFlowNET.UnitTest
var e = expected.OfType