| @@ -376,19 +376,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | return tf_with(ops.name_scope(name, "cond", new { pred }), delegate | ||||
| { | { | ||||
| // TODO: here a chunk of original code is missing | |||||
| /* | |||||
| with ops.name_scope(name, "cond", [pred]): | |||||
| if context.executing_eagerly(): | |||||
| if pred: | |||||
| return _UnpackIfSingleton(true_fn()) | |||||
| return _UnpackIfSingleton(false_fn()) | |||||
| */ | |||||
| // Add the Switch to the graph. | // Add the Switch to the graph. | ||||
| var switch_result= @switch(pred, pred); | var switch_result= @switch(pred, pred); | ||||
| var p_2=switch_result[0]; | |||||
| var p_1 = switch_result[1]; | |||||
| var (p_2, p_1 )= (switch_result[0], switch_result[1]); | |||||
| var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | var pivot_1 = array_ops.identity(p_1, name: "switch_t"); | ||||
| var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | var pivot_2 = array_ops.identity(p_2, name: "switch_f"); | ||||
| pred = array_ops.identity(pred, name: "pred_id"); | pred = array_ops.identity(pred, name: "pred_id"); | ||||
| @@ -405,6 +395,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| context_t.Enter(); | context_t.Enter(); | ||||
| (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | ||||
| context_t.ExitResult(new[] { res_t }); | |||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| @@ -418,46 +409,36 @@ namespace Tensorflow | |||||
| { | { | ||||
| context_f.Enter(); | context_f.Enter(); | ||||
| (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | ||||
| context_f.ExitResult(new[] { res_f }); | |||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| context_f.Exit(); | context_f.Exit(); | ||||
| } | } | ||||
| //TODO: missing original code | |||||
| //if not strict: | |||||
| // orig_res_t = _UnpackIfSingleton(orig_res_t) | |||||
| // orig_res_f = _UnpackIfSingleton(orig_res_f) | |||||
| /* | |||||
| # Check that the return values of the two branches have the same structure. | |||||
| try: | |||||
| nest.assert_same_structure(orig_res_t, orig_res_f) | |||||
| except TypeError as e: | |||||
| raise TypeError( | |||||
| "Incompatible return types of true_fn and false_fn: {}".format(e)) | |||||
| except ValueError as e: | |||||
| raise ValueError( | |||||
| "Incompatible return values of true_fn and false_fn: {}".format(e))*/ | |||||
| var res_t_flat = new Tensor[] { res_t }; | var res_t_flat = new Tensor[] { res_t }; | ||||
| var res_f_flat = new Tensor[] { res_f }; | var res_f_flat = new Tensor[] { res_f }; | ||||
| foreach(var (val_x, val_y) in zip(res_t_flat, res_f_flat)) | |||||
| { | |||||
| } | |||||
| var merges = zip(res_f_flat, res_t_flat) | var merges = zip(res_f_flat, res_t_flat) | ||||
| .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })) | |||||
| .Select(m => (Tensor)m) | |||||
| .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })[0]) | |||||
| .ToArray(); | .ToArray(); | ||||
| var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges); | |||||
| if (orig_res_t is Tensor orig_res_tensor) | |||||
| merges = _convert_flows_to_tensorarrays(new[] { orig_res_tensor }, merges) | |||||
| .Select(x => x as Tensor) | |||||
| .ToArray(); | |||||
| else | |||||
| { | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||||
| } | |||||
| return new Tensor(IntPtr.Zero); | |||||
| if(context_t.outer_context == null) | |||||
| { | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||||
| } | |||||
| return merges[0]; | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -485,28 +466,43 @@ namespace Tensorflow | |||||
| var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | ||||
| context_t.Enter(); | context_t.Enter(); | ||||
| var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | ||||
| context_t.ExitResult(res_t); | |||||
| context_t.Exit(); | context_t.Exit(); | ||||
| // Build the graph for the false branch in a new context. | // Build the graph for the false branch in a new context. | ||||
| var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | ||||
| context_f.Enter(); | context_f.Enter(); | ||||
| var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | ||||
| context_f.ExitResult(res_f); | |||||
| context_f.Exit(); | context_f.Exit(); | ||||
| var res_t_flat = res_t; | var res_t_flat = res_t; | ||||
| var res_f_flat = res_f; | var res_f_flat = res_f; | ||||
| var merges = zip(res_f_flat, res_t_flat) | var merges = zip(res_f_flat, res_t_flat) | ||||
| .Select(pair => merge(new [] { pair.Item1, pair.Item2 })) | |||||
| .Select(m => (Tensor)m) | |||||
| .Select(pair => merge(new [] { pair.Item1, pair.Item2 })[0]) | |||||
| .ToArray(); | .ToArray(); | ||||
| var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges); | |||||
| if (orig_res_t is Tensor[] orig_res_tensor) | |||||
| merges = _convert_flows_to_tensorarrays(orig_res_tensor, merges) | |||||
| .Select(x => x as Tensor) | |||||
| .ToArray(); | |||||
| else if (orig_res_t is float[] orig_res_float) | |||||
| { | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||||
| } | |||||
| else | |||||
| { | |||||
| } | |||||
| if(context_t.outer_context == null) | |||||
| { | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); | |||||
| ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); | |||||
| } | |||||
| return new[] { new Tensor(IntPtr.Zero) }; | |||||
| return merges; | |||||
| }); | }); | ||||
| } | } | ||||