Browse Source

add context_t.ExitResult for control_flow_ops.cond

tags/v0.13
Oceania2018 6 years ago
parent
commit
a5bf2f4f9f
1 changed files with 39 additions and 43 deletions
  1. +39
    -43
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs

+ 39
- 43
src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -376,19 +376,9 @@ namespace Tensorflow
{
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
{
// TODO: here a chunk of original code is missing
/*
with ops.name_scope(name, "cond", [pred]):
if context.executing_eagerly():
if pred:
return _UnpackIfSingleton(true_fn())
return _UnpackIfSingleton(false_fn())
*/

// Add the Switch to the graph.
var switch_result= @switch(pred, pred);
var p_2=switch_result[0];
var p_1 = switch_result[1];
var (p_2, p_1 )= (switch_result[0], switch_result[1]);
var pivot_1 = array_ops.identity(p_1, name: "switch_t");
var pivot_2 = array_ops.identity(p_2, name: "switch_f");
pred = array_ops.identity(pred, name: "pred_id");
@@ -405,6 +395,7 @@ namespace Tensorflow
{
context_t.Enter();
(orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
context_t.ExitResult(new[] { res_t });
}
finally
{
@@ -418,46 +409,36 @@ namespace Tensorflow
{
context_f.Enter();
(orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
context_f.ExitResult(new[] { res_f });
}
finally
{
context_f.Exit();
}

//TODO: missing original code
//if not strict:
// orig_res_t = _UnpackIfSingleton(orig_res_t)
// orig_res_f = _UnpackIfSingleton(orig_res_f)
/*
# Check that the return values of the two branches have the same structure.
try:
nest.assert_same_structure(orig_res_t, orig_res_f)
except TypeError as e:
raise TypeError(
"Incompatible return types of true_fn and false_fn: {}".format(e))
except ValueError as e:
raise ValueError(
"Incompatible return values of true_fn and false_fn: {}".format(e))*/

var res_t_flat = new Tensor[] { res_t };
var res_f_flat = new Tensor[] { res_f };

foreach(var (val_x, val_y) in zip(res_t_flat, res_f_flat))
{

}
var merges = zip(res_f_flat, res_t_flat)
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
.Select(m => (Tensor)m)
.Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })[0])
.ToArray();

var merges2 = _convert_flows_to_tensorarrays(new ITensorOrTensorArray[] { (Tensor)orig_res_t }, merges);
if (orig_res_t is Tensor orig_res_tensor)
merges = _convert_flows_to_tensorarrays(new[] { orig_res_tensor }, merges)
.Select(x => x as Tensor)
.ToArray();
else
{

ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);
}

return new Tensor(IntPtr.Zero);
if(context_t.outer_context == null)
{
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);
}

return merges[0];
});
}

@@ -485,28 +466,43 @@ namespace Tensorflow
var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
context_t.Enter();
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
context_t.ExitResult(res_t);
context_t.Exit();

// Build the graph for the false branch in a new context.
var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
context_f.Enter();
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
context_f.ExitResult(res_f);
context_f.Exit();

var res_t_flat = res_t;
var res_f_flat = res_f;

var merges = zip(res_f_flat, res_t_flat)
.Select(pair => merge(new [] { pair.Item1, pair.Item2 }))
.Select(m => (Tensor)m)
.Select(pair => merge(new [] { pair.Item1, pair.Item2 })[0])
.ToArray();

var merges2 = _convert_flows_to_tensorarrays(orig_res_t.Select(x => (ITensorOrTensorArray)x).ToArray(), merges);
if (orig_res_t is Tensor[] orig_res_tensor)
merges = _convert_flows_to_tensorarrays(orig_res_tensor, merges)
.Select(x => x as Tensor)
.ToArray();
else if (orig_res_t is float[] orig_res_float)
{

ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);
}
else
{

}

if(context_t.outer_context == null)
{
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);
}

return new[] { new Tensor(IntPtr.Zero) };
return merges;
});
}



Loading…
Cancel
Save