From 04ffa46d7b7eb720ae9ed33d6ef4b60a44969951 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 10 Apr 2019 21:17:02 -0500 Subject: [PATCH] fix CondContext.AddValue with ops.control_dependencies #213 --- .../Operations/ControlFlows/CondContext.cs | 16 ++++++++-------- .../ControlFlows/ControlFlowContext.cs | 2 +- src/TensorFlowNET.Core/Operations/check_ops.cs | 3 ++- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index 5fd25faa..385caf1c 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -88,14 +88,14 @@ namespace Tensorflow.Operations _values.Add(result.name); _external_values[result.name] = result; } - // TODO: how to do 'with' here?? - //with(ops.control_dependencies(null), ctrl => - //{ - var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred); - result = new[]{r0, r1}[_branch]; - if (_outer_context != null) - _outer_context.AddInnerOp(result.op); - //}); + + with(ops.control_dependencies(null), ctrl => + { + var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred); + result = new[] { r0, r1 }[_branch]; + if (_outer_context != null) + _outer_context.AddInnerOp(result.op); + }); result.op.graph.prevent_fetching(result.op); result.op._set_control_flow_context(this); diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index b7cc911d..0556b526 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -22,7 +22,7 @@ namespace Tensorflow.Operations /// 4. A ControlFlowContext has _context_stack. /// Pushed and popped by ctxt.Enter() and ctxt.Exit() /// - public abstract class ControlFlowContext : IPython, IControlFlowContext + public abstract class ControlFlowContext : Python, IPython, IControlFlowContext { /// /// The predicate tensor in this branch diff --git a/src/TensorFlowNET.Core/Operations/check_ops.cs b/src/TensorFlowNET.Core/Operations/check_ops.cs index 8f395223..87e26f7b 100644 --- a/src/TensorFlowNET.Core/Operations/check_ops.cs +++ b/src/TensorFlowNET.Core/Operations/check_ops.cs @@ -35,7 +35,8 @@ namespace Tensorflow }; } - var condition = math_ops.reduce_all(gen_math_ops.equal(x, y)); + var eq = gen_math_ops.equal(x, y); + var condition = math_ops.reduce_all(eq); var x_static = tensor_util.constant_value(x); var y_static = tensor_util.constant_value(y); return control_flow_ops.Assert(condition, data);