| @@ -376,19 +376,9 @@ namespace Tensorflow | |||
| { | |||
| return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | |||
| { | |||
| // TODO: here a chunk of original code is missing | |||
| /* | |||
| with ops.name_scope(name, "cond", [pred]): | |||
| if context.executing_eagerly(): | |||
| if pred: | |||
| return _UnpackIfSingleton(true_fn()) | |||
| return _UnpackIfSingleton(false_fn()) | |||
| */ | |||
| // Add the Switch to the graph. | |||
| var switch_result= @switch(pred, pred); | |||
| var p_2=switch_result[0]; | |||
| var p_1 = switch_result[1]; | |||
| var (p_2, p_1 )= (switch_result[0], switch_result[1]); | |||
| var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | |||
| var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | |||
| pred = array_ops.identity(pred, name: "pred_id"); | |||
| @@ -405,6 +395,7 @@ namespace Tensorflow | |||
| { | |||
| context_t.Enter(); | |||
| (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | |||
| context_t.ExitResult(new[] { res_t }); | |||
| } | |||
| finally | |||
| { | |||
| @@ -418,46 +409,36 @@ namespace Tensorflow | |||
| { | |||
| context_f.Enter(); | |||
| (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | |||
| context_f.ExitResult(new[] { res_f }); | |||
| } | |||
| finally | |||
| { | |||
| context_f.Exit(); | |||
| } | |||
| //TODO: missing original code | |||
| //if not strict: | |||
| // orig_res_t = _UnpackIfSingleton(orig_res_t) | |||
| // orig_res_f = _UnpackIfSingleton(orig_res_f) | |||
| /* | |||
| # Check that the return values of the two branches have the same structure. | |||
| try: | |||
| nest.assert_same_structure(orig_res_t, orig_res_f) | |||
| except TypeError as e: | |||
| raise TypeError( | |||
| "Incompatible return types of true_fn and false_fn: {}".format(e)) | |||
| except ValueError as e: | |||
| raise ValueError( | |||
| "Incompatible return values of true_fn and false_fn: {}".format(e))*/ | |||
| var res_t_flat = new Tensor[] { res_t }; | |||
| var res_f_flat = new Tensor[] { res_f }; | |||
| foreach(var (val_x, val_y) in zip(res_t_flat, res_f_flat)) | |||
| { | |||
| } | |||
| var merges = zip(res_f_flat, res_t_flat) | |||
| .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) | |||
| .Select(m => (Tensor)m) | |||
| .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })[0]) | |||
| .ToArray(); | |||
| var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges); | |||
| if (orig_res_t is Tensor orig_res_tensor) | |||
| merges = _convert_flows_to_tensorarrays(new[] { orig_res_tensor }, merges) | |||
| .Select(x => x as Tensor) | |||
| .ToArray(); | |||
| else | |||
| { | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||
| } | |||
| return new Tensor(IntPtr.Zero); | |||
| if(context_t.outer_context == null) | |||
| { | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||
| } | |||
| return merges[0]; | |||
| }); | |||
| } | |||
| @@ -485,28 +466,43 @@ namespace Tensorflow | |||
| var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | |||
| context_t.Enter(); | |||
| var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | |||
| context_t.ExitResult(res_t); | |||
| context_t.Exit(); | |||
| // Build the graph for the false branch in a new context. | |||
| var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | |||
| context_f.Enter(); | |||
| var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | |||
| context_f.ExitResult(res_f); | |||
| context_f.Exit(); | |||
| var res_t_flat = res_t; | |||
| var res_f_flat = res_f; | |||
| var merges = zip(res_f_flat, res_t_flat) | |||
| .Select(pair => merge(new [] { pair.Item1, pair.Item2 })) | |||
| .Select(m => (Tensor)m) | |||
| .Select(pair => merge(new [] { pair.Item1, pair.Item2 })[0]) | |||
| .ToArray(); | |||
| var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges); | |||
| if (orig_res_t is Tensor[] orig_res_tensor) | |||
| merges = _convert_flows_to_tensorarrays(orig_res_tensor, merges) | |||
| .Select(x => x as Tensor) | |||
| .ToArray(); | |||
| else if (orig_res_t is float[] orig_res_float) | |||
| { | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||
| } | |||
| else | |||
| { | |||
| } | |||
| if(context_t.outer_context == null) | |||
| { | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||
| } | |||
| return new[] { new Tensor(IntPtr.Zero) }; | |||
| return merges; | |||
| }); | |||
| } | |||