From a5bf2f4f9fed7e00dbc6a68b71da333ecf2d054d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 7 Nov 2019 07:51:44 -0600 Subject: [PATCH] add context_t.ExitResult for control_flow_ops.cond --- .../Operations/control_flow_ops.cs | 82 +++++++++---------- 1 file changed, 39 insertions(+), 43 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index b8360939..ffa0675b 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -376,19 +376,9 @@ namespace Tensorflow { return tf_with(ops.name_scope(name, "cond", new { pred }), delegate { - // TODO: here a chunk of original code is missing - /* - with ops.name_scope(name, "cond", [pred]): - if context.executing_eagerly(): - if pred: - return _UnpackIfSingleton(true_fn()) - return _UnpackIfSingleton(false_fn()) - */ - // Add the Switch to the graph. var switch_result= @switch(pred, pred); - var p_2=switch_result[0]; - var p_1 = switch_result[1]; + var (p_2, p_1 )= (switch_result[0], switch_result[1]); var pivot_1 = array_ops.identity(p_1, name: "switch_t"); var pivot_2 = array_ops.identity(p_2, name: "switch_f"); pred = array_ops.identity(pred, name: "pred_id"); @@ -405,6 +395,7 @@ namespace Tensorflow { context_t.Enter(); (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); + context_t.ExitResult(new[] { res_t }); } finally { @@ -418,46 +409,36 @@ namespace Tensorflow { context_f.Enter(); (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); + context_f.ExitResult(new[] { res_f }); } finally { context_f.Exit(); } - //TODO: missing original code - //if not strict: - // orig_res_t = _UnpackIfSingleton(orig_res_t) - // orig_res_f = _UnpackIfSingleton(orig_res_f) - /* - # Check that the return values of the two branches have the same structure. - try: - nest.assert_same_structure(orig_res_t, orig_res_f) - except TypeError as e: - raise TypeError( - "Incompatible return types of true_fn and false_fn: {}".format(e)) - except ValueError as e: - raise ValueError( - "Incompatible return values of true_fn and false_fn: {}".format(e))*/ - var res_t_flat = new Tensor[] { res_t }; var res_f_flat = new Tensor[] { res_f }; - foreach(var (val_x, val_y) in zip(res_t_flat, res_f_flat)) - { - - } - var merges = zip(res_f_flat, res_t_flat) - .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) - .Select(m => (Tensor)m) + .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })[0]) .ToArray(); - var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges); + if (orig_res_t is Tensor orig_res_tensor) + merges = _convert_flows_to_tensorarrays(new[] { orig_res_tensor }, merges) + .Select(x => x as Tensor) + .ToArray(); + else + { - ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); - ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); + } - return new Tensor(IntPtr.Zero); + if(context_t.outer_context == null) + { + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); + } + + return merges[0]; }); } @@ -485,28 +466,43 @@ namespace Tensorflow var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); context_t.Enter(); var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); + context_t.ExitResult(res_t); context_t.Exit(); // Build the graph for the false branch in a new context. var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); context_f.Enter(); var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); + context_f.ExitResult(res_f); context_f.Exit(); var res_t_flat = res_t; var res_f_flat = res_f; var merges = zip(res_f_flat, res_t_flat) - .Select(pair => merge(new [] { pair.Item1, pair.Item2 })) - .Select(m => (Tensor)m) + .Select(pair => merge(new [] { pair.Item1, pair.Item2 })[0]) .ToArray(); - var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges); + if (orig_res_t is Tensor[] orig_res_tensor) + merges = _convert_flows_to_tensorarrays(orig_res_tensor, merges) + .Select(x => x as Tensor) + .ToArray(); + else if (orig_res_t is float[] orig_res_float) + { - ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); - ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); + } + else + { + + } + + if(context_t.outer_context == null) + { + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); + ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); + } - return new[] { new Tensor(IntPtr.Zero) }; + return merges; }); }