From bd1e8531875fdf8d564f4820d40cb59e7d48f55c Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 15 Apr 2019 20:49:17 -0500 Subject: [PATCH] fix unit test evaluate. --- .../Operations/ControlFlows/CondContext.cs | 64 ++++++++++++++++- .../ControlFlows/ControlFlowContext.cs | 69 +++---------------- .../Operations/nn_impl.py.cs | 3 +- src/TensorFlowNET.Core/ops.py.cs | 2 + test/TensorFlowNET.UnitTest/PythonTest.cs | 5 +- 5 files changed, 74 insertions(+), 69 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index 7b70a76a..47908e05 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -238,9 +238,67 @@ namespace Tensorflow.Operations return real_val; } - public override void AddInnerOp(Operation resultOp) - { - throw new NotImplementedException(); + protected override void _AddOpInternal(Operation op) + { + if (op.inputs.Length == 0) + { + //If we're in a while loop, remove any control inputs from outside the + // loop. + _RemoveExternalControlEdges(op); + if (!op.control_inputs.Any(input_op => OpInContext(input_op))) + op._add_control_input(_pivot.op); + } + else + { + // Make each input to 'op' available in this CondContext. If an input is + // already part of this context there's nothing to do, but if it's + // external, AddValue() will handle adding the appropriate Switch node and + // other bookkeeping. + for (int index = 0; index < op.inputs.Length; index++) + { + var x = op.inputs[index]; + Tensor real_x = null; + if (op.type == "Merge" && x.op.type == "NextIteration") + { + //# Edge case: if we're importing a while loop inside this CondContext, + //# AddValue() will not correctly handle the NextIteration inputs to + //# Merge node. The problem is that the NextIteration should also be + //# part of this context, but if we're importing it won't have been + //# processed and added to the context yet, so AddValue() will try to + //# add a Switch which results in an invalid graph. Instead, we use the + //# NextIteration input as-is here, and it will eventually be added to + //# the context via AddOp(). + real_x = x; + } + else + { + real_x = AddValue(x); + } + if (real_x != x) + op._update_input(index, real_x); + } + // Remove any external control dependency on this op. + _RemoveExternalControlEdges(op); + // TODO: implement below code dependencies + //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") + // op._add_control_input(_pivot.op); + } + + // Mark op's outputs as seen by this context and any outer contexts. + var output_names = op.outputs.Select(x => x.name).ToArray(); + IControlFlowContext ctxt = this; + while (ctxt != null) + { + foreach (var name in output_names) + ctxt.values.Add(name); + ctxt = ctxt.outer_context; + } + + if (_outer_context != null || !control_flow_ops.IsLoopExit(op)) + op.graph.prevent_fetching(op); + + if (_outer_context != null) + _outer_context.AddInnerOp(op); } public CondContextDef to_proto(string export_scope) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 86452e50..6dd8e25e 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -119,9 +119,14 @@ namespace Tensorflow.Operations return null; } - public virtual void AddInnerOp(Operation resultOp) + /// + /// Notifies a scope about an operator added to an inner scope. + /// + /// + public virtual void AddInnerOp(Operation op) { - // to be overridden + if (_outer_context != null) + _outer_context.AddInnerOp(op); } protected HashSet _values = new HashSet(); @@ -131,68 +136,10 @@ namespace Tensorflow.Operations /// protected virtual void _AddOpInternal(Operation op) { - if (op.inputs.Length == 0) - { - //If we're in a while loop, remove any control inputs from outside the - // loop. - _RemoveExternalControlEdges(op); - if (!op.control_inputs.Any(input_op => OpInContext(input_op))) - op._add_control_input(_pivot.op); - } - else - { - // Make each input to 'op' available in this CondContext. If an input is - // already part of this context there's nothing to do, but if it's - // external, AddValue() will handle adding the appropriate Switch node and - // other bookkeeping. - for (int index = 0; index < op.inputs.Length; index++) - { - var x = op.inputs[index]; - Tensor real_x = null; - if (op.type == "Merge" && x.op.type == "NextIteration") - { - //# Edge case: if we're importing a while loop inside this CondContext, - //# AddValue() will not correctly handle the NextIteration inputs to - //# Merge node. The problem is that the NextIteration should also be - //# part of this context, but if we're importing it won't have been - //# processed and added to the context yet, so AddValue() will try to - //# add a Switch which results in an invalid graph. Instead, we use the - //# NextIteration input as-is here, and it will eventually be added to - //# the context via AddOp(). - real_x = x; - } - else - { - real_x = AddValue(x); - } - if (real_x != x) - op._update_input(index, real_x); - } - // Remove any external control dependency on this op. - _RemoveExternalControlEdges(op); - // TODO: implement below code dependencies - //if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") - // op._add_control_input(_pivot.op); - } - // Mark op's outputs as seen by this context and any outer contexts. - var output_names = op.outputs.Select(x => x.name).ToArray(); - IControlFlowContext ctxt = this; - while (ctxt != null) - { - foreach(var name in output_names) - ctxt.values.Add(name); - ctxt = ctxt.outer_context; - } - - if (_outer_context != null || !control_flow_ops.IsLoopExit(op)) - op.graph.prevent_fetching(op); - - if (_outer_context != null) - _outer_context.AddInnerOp(op); } - private bool OpInContext(Operation op) + protected bool OpInContext(Operation op) { return IsContainingContext(op._get_control_flow_context(), this); } diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs index c5a86d31..2d050513 100644 --- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs @@ -23,7 +23,8 @@ namespace Tensorflow return with(ops.name_scope(name, "l2_normalize", new { x }), scope => { x = ops.convert_to_tensor(x, name: "x"); - var square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims: true); + var sq = math_ops.square(x); + var square_sum = math_ops.reduce_sum(sq, axis, keepdims: true); var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon)); return math_ops.multiply(x, x_inv_norm, name: name); }); diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index c147e1b5..63a2868a 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -360,6 +360,8 @@ namespace Tensorflow /// The default `Session` being used in the current thread. public static Session get_default_session() { + if (tf.defaultSession == null) + tf.defaultSession = tf.Session(); return tf.defaultSession; } diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index 9396da58..dcd8e80b 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -143,10 +143,7 @@ namespace TensorFlowNET.UnitTest // return self._eval_helper(tensors) // else: { - var sess = ops.get_default_session(); - if (sess == null) - sess = self.session(); - with(sess, s => + with(ops.get_default_session(), s => { var ndarray=tensor.eval(); if (typeof(T) == typeof(double))