diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
index 7b70a76a..47908e05 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
@@ -238,9 +238,67 @@ namespace Tensorflow.Operations
return real_val;
}
- public override void AddInnerOp(Operation resultOp)
- {
- throw new NotImplementedException();
+ protected override void _AddOpInternal(Operation op)
+ {
+ if (op.inputs.Length == 0)
+ {
+ //If we're in a while loop, remove any control inputs from outside the
+ // loop.
+ _RemoveExternalControlEdges(op);
+ if (!op.control_inputs.Any(input_op => OpInContext(input_op)))
+ op._add_control_input(_pivot.op);
+ }
+ else
+ {
+ // Make each input to 'op' available in this CondContext. If an input is
+ // already part of this context there's nothing to do, but if it's
+ // external, AddValue() will handle adding the appropriate Switch node and
+ // other bookkeeping.
+ for (int index = 0; index < op.inputs.Length; index++)
+ {
+ var x = op.inputs[index];
+ Tensor real_x = null;
+ if (op.type == "Merge" && x.op.type == "NextIteration")
+ {
+ //# Edge case: if we're importing a while loop inside this CondContext,
+ //# AddValue() will not correctly handle the NextIteration inputs to
+ //# Merge node. The problem is that the NextIteration should also be
+ //# part of this context, but if we're importing it won't have been
+ //# processed and added to the context yet, so AddValue() will try to
+ //# add a Switch which results in an invalid graph. Instead, we use the
+ //# NextIteration input as-is here, and it will eventually be added to
+ //# the context via AddOp().
+ real_x = x;
+ }
+ else
+ {
+ real_x = AddValue(x);
+ }
+ if (real_x != x)
+ op._update_input(index, real_x);
+ }
+ // Remove any external control dependency on this op.
+ _RemoveExternalControlEdges(op);
+ // TODO: implement below code dependencies
+ //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient")
+ // op._add_control_input(_pivot.op);
+ }
+
+ // Mark op's outputs as seen by this context and any outer contexts.
+ var output_names = op.outputs.Select(x => x.name).ToArray();
+ IControlFlowContext ctxt = this;
+ while (ctxt != null)
+ {
+ foreach (var name in output_names)
+ ctxt.values.Add(name);
+ ctxt = ctxt.outer_context;
+ }
+
+ if (_outer_context != null || !control_flow_ops.IsLoopExit(op))
+ op.graph.prevent_fetching(op);
+
+ if (_outer_context != null)
+ _outer_context.AddInnerOp(op);
}
public CondContextDef to_proto(string export_scope)
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
index 86452e50..6dd8e25e 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
@@ -119,9 +119,14 @@ namespace Tensorflow.Operations
return null;
}
- public virtual void AddInnerOp(Operation resultOp)
+ ///
+ /// Notifies a scope about an operator added to an inner scope.
+ ///
+ ///
+ public virtual void AddInnerOp(Operation op)
{
- // to be overridden
+ if (_outer_context != null)
+ _outer_context.AddInnerOp(op);
}
protected HashSet _values = new HashSet();
@@ -131,68 +136,10 @@ namespace Tensorflow.Operations
///
protected virtual void _AddOpInternal(Operation op)
{
- if (op.inputs.Length == 0)
- {
- //If we're in a while loop, remove any control inputs from outside the
- // loop.
- _RemoveExternalControlEdges(op);
- if (!op.control_inputs.Any(input_op => OpInContext(input_op)))
- op._add_control_input(_pivot.op);
- }
- else
- {
- // Make each input to 'op' available in this CondContext. If an input is
- // already part of this context there's nothing to do, but if it's
- // external, AddValue() will handle adding the appropriate Switch node and
- // other bookkeeping.
- for (int index = 0; index < op.inputs.Length; index++)
- {
- var x = op.inputs[index];
- Tensor real_x = null;
- if (op.type == "Merge" && x.op.type == "NextIteration")
- {
- //# Edge case: if we're importing a while loop inside this CondContext,
- //# AddValue() will not correctly handle the NextIteration inputs to
- //# Merge node. The problem is that the NextIteration should also be
- //# part of this context, but if we're importing it won't have been
- //# processed and added to the context yet, so AddValue() will try to
- //# add a Switch which results in an invalid graph. Instead, we use the
- //# NextIteration input as-is here, and it will eventually be added to
- //# the context via AddOp().
- real_x = x;
- }
- else
- {
- real_x = AddValue(x);
- }
- if (real_x != x)
- op._update_input(index, real_x);
- }
- // Remove any external control dependency on this op.
- _RemoveExternalControlEdges(op);
- // TODO: implement below code dependencies
- //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient")
- // op._add_control_input(_pivot.op);
- }
- // Mark op's outputs as seen by this context and any outer contexts.
- var output_names = op.outputs.Select(x => x.name).ToArray();
- IControlFlowContext ctxt = this;
- while (ctxt != null)
- {
- foreach(var name in output_names)
- ctxt.values.Add(name);
- ctxt = ctxt.outer_context;
- }
-
- if (_outer_context != null || !control_flow_ops.IsLoopExit(op))
- op.graph.prevent_fetching(op);
-
- if (_outer_context != null)
- _outer_context.AddInnerOp(op);
}
- private bool OpInContext(Operation op)
+ protected bool OpInContext(Operation op)
{
return IsContainingContext(op._get_control_flow_context(), this);
}
diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
index c5a86d31..2d050513 100644
--- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
@@ -23,7 +23,8 @@ namespace Tensorflow
return with(ops.name_scope(name, "l2_normalize", new { x }), scope =>
{
x = ops.convert_to_tensor(x, name: "x");
- var square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims: true);
+ var sq = math_ops.square(x);
+ var square_sum = math_ops.reduce_sum(sq, axis, keepdims: true);
var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon));
return math_ops.multiply(x, x_inv_norm, name: name);
});
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index c147e1b5..63a2868a 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -360,6 +360,8 @@ namespace Tensorflow
/// The default `Session` being used in the current thread.
public static Session get_default_session()
{
+ if (tf.defaultSession == null)
+ tf.defaultSession = tf.Session();
return tf.defaultSession;
}
diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs
index 9396da58..dcd8e80b 100644
--- a/test/TensorFlowNET.UnitTest/PythonTest.cs
+++ b/test/TensorFlowNET.UnitTest/PythonTest.cs
@@ -143,10 +143,7 @@ namespace TensorFlowNET.UnitTest
// return self._eval_helper(tensors)
// else:
{
- var sess = ops.get_default_session();
- if (sess == null)
- sess = self.session();
- with(sess, s =>
+ with(ops.get_default_session(), s =>
{
var ndarray=tensor.eval();
if (typeof(T) == typeof(double))