|
|
|
@@ -71,6 +71,8 @@ namespace Tensorflow.Operations |
|
|
|
string name) |
|
|
|
{ |
|
|
|
_name = ops.get_default_graph().unique_name(name); |
|
|
|
_maximum_iterations = maximum_iterations; |
|
|
|
_parallel_iterations = parallel_iterations; |
|
|
|
_back_prop = back_prop; |
|
|
|
_swap_memory = swap_memory; |
|
|
|
_loop_exits = new List<Tensor>(); |
|
|
|
@@ -107,18 +109,27 @@ namespace Tensorflow.Operations |
|
|
|
/// <summary> |
|
|
|
/// Add the loop termination condition and body to the graph. |
|
|
|
/// </summary> |
|
|
|
internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred, |
|
|
|
Func<Tensor, TItem, LoopVar<TItem>> body, |
|
|
|
internal Tensor[] BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred, |
|
|
|
Func<LoopVar<TItem>, LoopVar<TItem>> body, |
|
|
|
LoopVar<TItem> loop_vars, |
|
|
|
TensorShape shape_invariants, |
|
|
|
TensorShape[] shape_invariants, |
|
|
|
bool return_same_structure) |
|
|
|
{ |
|
|
|
// Keep original_loop_vars to identify which are TensorArrays |
|
|
|
var original_loop_vars = loop_vars; |
|
|
|
// Convert TensorArrays to their flow variables |
|
|
|
var loop_vars_tensors = nest.flatten2(loop_vars) |
|
|
|
.Select(x => _convert_tensorarray_to_flow(x)) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
if (shape_invariants == null) |
|
|
|
shape_invariants = loop_vars_tensors |
|
|
|
.Select(x => _get_shape_invariant(x as Tensor)) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
Enter(); |
|
|
|
var(original_body_result, exit_vars) = _BuildLoop( |
|
|
|
pred, body, original_loop_vars, loop_vars, shape_invariants); |
|
|
|
pred, body, original_loop_vars, loop_vars_tensors, shape_invariants); |
|
|
|
Exit(); |
|
|
|
|
|
|
|
var flat_result = original_body_result; |
|
|
|
@@ -131,7 +142,7 @@ namespace Tensorflow.Operations |
|
|
|
return packed_exit_vars as Tensor[]; |
|
|
|
} |
|
|
|
|
|
|
|
private Tensor _convert_tensorarray_to_flow<TItem>(TItem tensor_or_tensor_array) |
|
|
|
private Tensor _convert_tensorarray_to_flow(object tensor_or_tensor_array) |
|
|
|
{ |
|
|
|
if (tensor_or_tensor_array is TensorArray tensor_array) |
|
|
|
return tensor_array.flow; |
|
|
|
@@ -141,97 +152,116 @@ namespace Tensorflow.Operations |
|
|
|
throw new NotImplementedException("_convert_tensorarray_to_flow"); |
|
|
|
} |
|
|
|
|
|
|
|
private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred, |
|
|
|
Func<Tensor, TItem, LoopVar<TItem>> body, |
|
|
|
LoopVar<TItem> original_loop_vars, |
|
|
|
LoopVar<TItem> loop_vars, |
|
|
|
TensorShape shape_invariants) |
|
|
|
private TensorShape _get_shape_invariant(Tensor var, int[] shape = null) |
|
|
|
{ |
|
|
|
var flat_loop_vars = original_loop_vars; |
|
|
|
return var.TensorShape; |
|
|
|
} |
|
|
|
|
|
|
|
// Convert TensorArrays to their flow variables |
|
|
|
var loop_vars_tensor = nest.map_structure( |
|
|
|
_convert_tensorarray_to_flow, |
|
|
|
nest.flatten2(loop_vars)); |
|
|
|
/// <summary> |
|
|
|
/// Add the loop termination condition and body to the graph. |
|
|
|
/// </summary> |
|
|
|
/// <typeparam name="TItem"></typeparam> |
|
|
|
/// <param name="pred"></param> |
|
|
|
/// <param name="body"></param> |
|
|
|
/// <param name="original_loop_vars"></param> |
|
|
|
/// <param name="loop_vars"></param> |
|
|
|
/// <param name="shape_invariants"></param> |
|
|
|
/// <returns></returns> |
|
|
|
private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred, |
|
|
|
Func<LoopVar<TItem>, LoopVar<TItem>> body, |
|
|
|
LoopVar<TItem> original_loop_vars, |
|
|
|
Tensor[] loop_vars, |
|
|
|
TensorShape[] shape_invariants) |
|
|
|
{ |
|
|
|
var flat_loop_vars = nest.flatten2(original_loop_vars) |
|
|
|
.Select(x => (ITensorOrTensorArray)x) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
// Let the context know the loop variables so the loop variables |
|
|
|
// would be added in the outer contexts properly. |
|
|
|
if (loop_vars is Tensor[] real_vars) |
|
|
|
_InitializeValues(loop_vars); |
|
|
|
var real_vars = loop_vars; |
|
|
|
Tensor[] enter_vars = null; |
|
|
|
tf_with(ops.control_dependencies(null), delegate |
|
|
|
{ |
|
|
|
_InitializeValues(real_vars); |
|
|
|
Tensor[] enter_vars = null; |
|
|
|
tf_with(ops.control_dependencies(null), delegate |
|
|
|
{ |
|
|
|
enter_vars = real_vars.Select(x => _Enter(x, |
|
|
|
_name, |
|
|
|
is_constant: false, |
|
|
|
parallel_iterations: _parallel_iterations, |
|
|
|
use_input_shape: shape_invariants == null)) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
foreach (var x in enter_vars) |
|
|
|
{ |
|
|
|
x.graph.prevent_feeding(x); |
|
|
|
if (_outer_context != null) |
|
|
|
_outer_context.AddInnerOp(x.op); |
|
|
|
} |
|
|
|
}); |
|
|
|
|
|
|
|
// Finds the closest enclosing non-None control pivot. |
|
|
|
var outer_context = _outer_context; |
|
|
|
while (outer_context != null) |
|
|
|
enter_vars = real_vars.Select(x => _Enter(x, |
|
|
|
_name, |
|
|
|
is_constant: false, |
|
|
|
parallel_iterations: _parallel_iterations, |
|
|
|
use_input_shape: shape_invariants == null)) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
foreach (var x in enter_vars) |
|
|
|
{ |
|
|
|
|
|
|
|
x.graph.prevent_feeding(x); |
|
|
|
if (_outer_context != null) |
|
|
|
_outer_context.AddInnerOp(x.op); |
|
|
|
} |
|
|
|
}); |
|
|
|
|
|
|
|
_SetShapeInvariants(real_vars, enter_vars, shape_invariants); |
|
|
|
|
|
|
|
// Fix the control inputs and control flow context of these enter ops. |
|
|
|
_FixControlInputsAndContext(enter_vars); |
|
|
|
_InitializeValues(enter_vars); |
|
|
|
_loop_enters = enter_vars.ToList(); |
|
|
|
|
|
|
|
var merge_vars = enter_vars |
|
|
|
.Select(x => merge(new[] { x, x })) |
|
|
|
.ToArray(); |
|
|
|
// Finds the closest enclosing non-None control pivot. |
|
|
|
var outer_context = _outer_context; |
|
|
|
object control_pivot = null; |
|
|
|
while (outer_context != null && control_pivot == null) |
|
|
|
{ |
|
|
|
|
|
|
|
_pivot_for_pred = merge_vars[0]; |
|
|
|
} |
|
|
|
|
|
|
|
// Build the graph for pred. |
|
|
|
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); |
|
|
|
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); |
|
|
|
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem))); |
|
|
|
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); |
|
|
|
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) |
|
|
|
.ToArray(); |
|
|
|
if (control_pivot != null) |
|
|
|
{ |
|
|
|
|
|
|
|
// Build the graph for body. |
|
|
|
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); |
|
|
|
// Convert TensorArray flow variables inside the context back into |
|
|
|
// their associated TensorArrays for calling the body. |
|
|
|
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); |
|
|
|
/*var body_result = body(packed_vars_for_body[0]); |
|
|
|
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); |
|
|
|
|
|
|
|
// Store body_result to keep track of TensorArrays returned by body |
|
|
|
var original_body_result = new[] { body_result }; |
|
|
|
// Convert TensorArrays returned by body into their flow variables |
|
|
|
var result = new[] { body_result }; |
|
|
|
|
|
|
|
var next_vars = new List<Tensor>(); |
|
|
|
foreach (var (m, v) in zip(merge_vars, result)) |
|
|
|
next_vars.Add(_AddNextAndBackEdge(m, v)); |
|
|
|
|
|
|
|
// Add the exit ops. |
|
|
|
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); |
|
|
|
_loop_exits = exit_vars; |
|
|
|
|
|
|
|
// Exit the loop. |
|
|
|
// ExitResult(exit_vars); |
|
|
|
return (original_body_result, exit_vars.ToArray());*/ |
|
|
|
} |
|
|
|
|
|
|
|
throw new NotImplementedException(""); |
|
|
|
_SetShapeInvariants(real_vars, enter_vars, shape_invariants); |
|
|
|
|
|
|
|
// Fix the control inputs and control flow context of these enter ops. |
|
|
|
_FixControlInputsAndContext(enter_vars); |
|
|
|
_InitializeValues(enter_vars); |
|
|
|
_loop_enters = enter_vars.ToList(); |
|
|
|
|
|
|
|
var merge_vars = enter_vars |
|
|
|
.Select(x => merge(new[] { x, x })) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
_pivot_for_pred = merge_vars[0]; |
|
|
|
|
|
|
|
// Build the graph for pred. |
|
|
|
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); |
|
|
|
//var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays, expand_composites: true); |
|
|
|
var packed_vars = new LoopVar<TItem>((Tensor)merge_vars_with_tensor_arrays[0], |
|
|
|
(TItem)(object)new BodyItemInRnnWhileLoop((Tensor)merge_vars_with_tensor_arrays[1], |
|
|
|
new[] { (TensorArray)merge_vars_with_tensor_arrays[2] }, |
|
|
|
(Tensor)merge_vars_with_tensor_arrays[3])); |
|
|
|
var pp = pred(packed_vars); |
|
|
|
var c = ops.convert_to_tensor(pp); |
|
|
|
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); |
|
|
|
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) |
|
|
|
.ToArray(); |
|
|
|
|
|
|
|
// Build the graph for body. |
|
|
|
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); |
|
|
|
// Convert TensorArray flow variables inside the context back into |
|
|
|
// their associated TensorArrays for calling the body. |
|
|
|
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); |
|
|
|
var body_result = body(original_loop_vars); |
|
|
|
var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); |
|
|
|
|
|
|
|
// Store body_result to keep track of TensorArrays returned by body |
|
|
|
var original_body_result = new[] { body_result }; |
|
|
|
// Convert TensorArrays returned by body into their flow variables |
|
|
|
var result = new[] { body_result }; |
|
|
|
|
|
|
|
var next_vars = new List<Tensor>(); |
|
|
|
//foreach (var (m, v) in zip(merge_vars, result)) |
|
|
|
//next_vars.Add(_AddNextAndBackEdge(m, v)); |
|
|
|
|
|
|
|
// Add the exit ops. |
|
|
|
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); |
|
|
|
_loop_exits = exit_vars; |
|
|
|
|
|
|
|
// Exit the loop. |
|
|
|
// ExitResult(exit_vars); |
|
|
|
return (null, exit_vars.ToArray()); |
|
|
|
} |
|
|
|
|
|
|
|
private void _FixControlInputsAndContext(Tensor[] enters) |
|
|
|
@@ -258,6 +288,23 @@ namespace Tensorflow.Operations |
|
|
|
_values.Add(x.name); |
|
|
|
} |
|
|
|
|
|
|
|
public override Tensor AddValue(Tensor val) |
|
|
|
{ |
|
|
|
var result = val; |
|
|
|
var new_value = _values.Contains(val.name); |
|
|
|
new_value &= val.op._get_control_flow_context() != this; |
|
|
|
if (new_value) |
|
|
|
throw new NotImplementedException(""); |
|
|
|
else |
|
|
|
{ |
|
|
|
var actual_val = _external_values.ContainsKey(val.name) ? _external_values[val.name] : null; |
|
|
|
if (actual_val != null) |
|
|
|
result = actual_val as Tensor; |
|
|
|
} |
|
|
|
|
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
public override WhileContext GetWhileContext() |
|
|
|
{ |
|
|
|
return this; |
|
|
|
|