| @@ -42,7 +42,7 @@ namespace Tensorflow.Operations | |||||
| public override GradLoopState grad_state => _grad_state; | public override GradLoopState grad_state => _grad_state; | ||||
| public override bool back_prop => _back_prop; | public override bool back_prop => _back_prop; | ||||
| public WhileContext(int? maximum_iterations = null, | |||||
| public WhileContext(Tensor maximum_iterations = null, | |||||
| int parallel_iterations = 10, | int parallel_iterations = 10, | ||||
| bool back_prop = true, | bool back_prop = true, | ||||
| bool swap_memory = false, | bool swap_memory = false, | ||||
| @@ -64,7 +64,7 @@ namespace Tensorflow.Operations | |||||
| _grad_state = grad_state; | _grad_state = grad_state; | ||||
| } | } | ||||
| private void _init_from_args(int? maximum_iterations, | |||||
| private void _init_from_args(Tensor maximum_iterations, | |||||
| int parallel_iterations, | int parallel_iterations, | ||||
| bool back_prop, | bool back_prop, | ||||
| bool swap_memory, | bool swap_memory, | ||||
| @@ -107,9 +107,9 @@ namespace Tensorflow.Operations | |||||
| /// <summary> | /// <summary> | ||||
| /// Add the loop termination condition and body to the graph. | /// Add the loop termination condition and body to the graph. | ||||
| /// </summary> | /// </summary> | ||||
| public Tensor[] BuildLoop(Func<Tensor, Tensor> pred, | |||||
| Func<Tensor, Tensor> body, | |||||
| Tensor[] loop_vars, | |||||
| internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred, | |||||
| Func<Tensor, TItem, LoopVar<TItem>> body, | |||||
| TItem loop_vars, | |||||
| TensorShape shape_invariants, | TensorShape shape_invariants, | ||||
| bool return_same_structure) | bool return_same_structure) | ||||
| { | { | ||||
| @@ -131,88 +131,107 @@ namespace Tensorflow.Operations | |||||
| return packed_exit_vars as Tensor[]; | return packed_exit_vars as Tensor[]; | ||||
| } | } | ||||
| private (Tensor[], Tensor[]) _BuildLoop(Func<Tensor, Tensor> pred, | |||||
| Func<Tensor, Tensor> body, | |||||
| Tensor[] original_loop_vars, | |||||
| Tensor[] loop_vars, | |||||
| private Tensor _convert_tensorarray_to_flow<TItem>(TItem tensor_or_tensor_array) | |||||
| { | |||||
| if (tensor_or_tensor_array is TensorArray tensor_array) | |||||
| return tensor_array.flow; | |||||
| else if (tensor_or_tensor_array is Tensor tensor) | |||||
| return tensor; | |||||
| throw new NotImplementedException("_convert_tensorarray_to_flow"); | |||||
| } | |||||
| private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred, | |||||
| Func<Tensor, TItem, LoopVar<TItem>> body, | |||||
| TItem original_loop_vars, | |||||
| TItem loop_vars, | |||||
| TensorShape shape_invariants) | TensorShape shape_invariants) | ||||
| { | { | ||||
| var flat_loop_vars = original_loop_vars; | var flat_loop_vars = original_loop_vars; | ||||
| // Convert TensorArrays to their flow variables | |||||
| var loop_vars_tensor = nest.map_structure( | |||||
| _convert_tensorarray_to_flow, | |||||
| nest.flatten(loop_vars)); | |||||
| // Let the context know the loop variables so the loop variables | // Let the context know the loop variables so the loop variables | ||||
| // would be added in the outer contexts properly. | // would be added in the outer contexts properly. | ||||
| _InitializeValues(loop_vars); | |||||
| var real_vars = loop_vars; | |||||
| Tensor[] enter_vars = null; | |||||
| tf_with(ops.control_dependencies(null), delegate | |||||
| if (loop_vars is Tensor[] real_vars) | |||||
| { | { | ||||
| enter_vars = real_vars.Select(x => _Enter(x, | |||||
| _name, | |||||
| is_constant: false, | |||||
| parallel_iterations: _parallel_iterations, | |||||
| use_input_shape: shape_invariants == null)) | |||||
| .ToArray(); | |||||
| foreach(var x in enter_vars) | |||||
| _InitializeValues(real_vars); | |||||
| Tensor[] enter_vars = null; | |||||
| tf_with(ops.control_dependencies(null), delegate | |||||
| { | |||||
| enter_vars = real_vars.Select(x => _Enter(x, | |||||
| _name, | |||||
| is_constant: false, | |||||
| parallel_iterations: _parallel_iterations, | |||||
| use_input_shape: shape_invariants == null)) | |||||
| .ToArray(); | |||||
| foreach (var x in enter_vars) | |||||
| { | |||||
| x.graph.prevent_feeding(x); | |||||
| if (_outer_context != null) | |||||
| _outer_context.AddInnerOp(x.op); | |||||
| } | |||||
| }); | |||||
| // Finds the closest enclosing non-None control pivot. | |||||
| var outer_context = _outer_context; | |||||
| while (outer_context != null) | |||||
| { | { | ||||
| x.graph.prevent_feeding(x); | |||||
| if (_outer_context != null) | |||||
| _outer_context.AddInnerOp(x.op); | |||||
| } | } | ||||
| }); | |||||
| // Finds the closest enclosing non-None control pivot. | |||||
| var outer_context = _outer_context; | |||||
| while (outer_context != null) | |||||
| { | |||||
| _SetShapeInvariants(real_vars, enter_vars, shape_invariants); | |||||
| // Fix the control inputs and control flow context of these enter ops. | |||||
| _FixControlInputsAndContext(enter_vars); | |||||
| _InitializeValues(enter_vars); | |||||
| _loop_enters = enter_vars.ToList(); | |||||
| var merge_vars = enter_vars | |||||
| .Select(x => merge(new[] { x, x })) | |||||
| .ToArray(); | |||||
| _pivot_for_pred = merge_vars[0]; | |||||
| // Build the graph for pred. | |||||
| var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); | |||||
| // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); | |||||
| var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem))); | |||||
| _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); | |||||
| var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) | |||||
| .ToArray(); | |||||
| // Build the graph for body. | |||||
| var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); | |||||
| // Convert TensorArray flow variables inside the context back into | |||||
| // their associated TensorArrays for calling the body. | |||||
| var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||||
| /*var body_result = body(packed_vars_for_body[0]); | |||||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| // Store body_result to keep track of TensorArrays returned by body | |||||
| var original_body_result = new[] { body_result }; | |||||
| // Convert TensorArrays returned by body into their flow variables | |||||
| var result = new[] { body_result }; | |||||
| var next_vars = new List<Tensor>(); | |||||
| foreach (var (m, v) in zip(merge_vars, result)) | |||||
| next_vars.Add(_AddNextAndBackEdge(m, v)); | |||||
| // Add the exit ops. | |||||
| var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); | |||||
| _loop_exits = exit_vars; | |||||
| // Exit the loop. | |||||
| // ExitResult(exit_vars); | |||||
| return (original_body_result, exit_vars.ToArray());*/ | |||||
| } | } | ||||
| _SetShapeInvariants(real_vars, enter_vars, shape_invariants); | |||||
| // Fix the control inputs and control flow context of these enter ops. | |||||
| _FixControlInputsAndContext(enter_vars); | |||||
| _InitializeValues(enter_vars); | |||||
| _loop_enters = enter_vars.ToList(); | |||||
| var merge_vars = enter_vars | |||||
| .Select(x => merge(new[] { x, x })) | |||||
| .ToArray(); | |||||
| _pivot_for_pred = merge_vars[0]; | |||||
| // Build the graph for pred. | |||||
| var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); | |||||
| // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); | |||||
| var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0])); | |||||
| _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); | |||||
| var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) | |||||
| .ToArray(); | |||||
| // Build the graph for body. | |||||
| var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); | |||||
| // Convert TensorArray flow variables inside the context back into | |||||
| // their associated TensorArrays for calling the body. | |||||
| var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||||
| var body_result = body(packed_vars_for_body[0]); | |||||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||||
| // Store body_result to keep track of TensorArrays returned by body | |||||
| var original_body_result = new[] { body_result }; | |||||
| // Convert TensorArrays returned by body into their flow variables | |||||
| var result = new[] { body_result }; | |||||
| var next_vars = new List<Tensor>(); | |||||
| foreach (var (m, v) in zip(merge_vars, result)) | |||||
| next_vars.Add(_AddNextAndBackEdge(m, v)); | |||||
| // Add the exit ops. | |||||
| var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); | |||||
| _loop_exits = exit_vars; | |||||
| // Exit the loop. | |||||
| // ExitResult(exit_vars); | |||||
| return (original_body_result, exit_vars.ToArray()); | |||||
| throw new NotImplementedException(""); | |||||
| } | } | ||||
| private void _FixControlInputsAndContext(Tensor[] enters) | private void _FixControlInputsAndContext(Tensor[] enters) | ||||