| @@ -42,7 +42,7 @@ namespace Tensorflow.Operations | |||
| public override GradLoopState grad_state => _grad_state; | |||
| public override bool back_prop => _back_prop; | |||
| public WhileContext(int? maximum_iterations = null, | |||
| public WhileContext(Tensor maximum_iterations = null, | |||
| int parallel_iterations = 10, | |||
| bool back_prop = true, | |||
| bool swap_memory = false, | |||
| @@ -64,7 +64,7 @@ namespace Tensorflow.Operations | |||
| _grad_state = grad_state; | |||
| } | |||
| private void _init_from_args(int? maximum_iterations, | |||
| private void _init_from_args(Tensor maximum_iterations, | |||
| int parallel_iterations, | |||
| bool back_prop, | |||
| bool swap_memory, | |||
| @@ -107,9 +107,9 @@ namespace Tensorflow.Operations | |||
| /// <summary> | |||
| /// Add the loop termination condition and body to the graph. | |||
| /// </summary> | |||
| public Tensor[] BuildLoop(Func<Tensor, Tensor> pred, | |||
| Func<Tensor, Tensor> body, | |||
| Tensor[] loop_vars, | |||
| internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred, | |||
| Func<Tensor, TItem, LoopVar<TItem>> body, | |||
| TItem loop_vars, | |||
| TensorShape shape_invariants, | |||
| bool return_same_structure) | |||
| { | |||
| @@ -131,88 +131,107 @@ namespace Tensorflow.Operations | |||
| return packed_exit_vars as Tensor[]; | |||
| } | |||
| private (Tensor[], Tensor[]) _BuildLoop(Func<Tensor, Tensor> pred, | |||
| Func<Tensor, Tensor> body, | |||
| Tensor[] original_loop_vars, | |||
| Tensor[] loop_vars, | |||
| private Tensor _convert_tensorarray_to_flow<TItem>(TItem tensor_or_tensor_array) | |||
| { | |||
| if (tensor_or_tensor_array is TensorArray tensor_array) | |||
| return tensor_array.flow; | |||
| else if (tensor_or_tensor_array is Tensor tensor) | |||
| return tensor; | |||
| throw new NotImplementedException("_convert_tensorarray_to_flow"); | |||
| } | |||
| private (Tensor[], Tensor[]) _BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred, | |||
| Func<Tensor, TItem, LoopVar<TItem>> body, | |||
| TItem original_loop_vars, | |||
| TItem loop_vars, | |||
| TensorShape shape_invariants) | |||
| { | |||
| var flat_loop_vars = original_loop_vars; | |||
| // Convert TensorArrays to their flow variables | |||
| var loop_vars_tensor = nest.map_structure( | |||
| _convert_tensorarray_to_flow, | |||
| nest.flatten(loop_vars)); | |||
| // Let the context know the loop variables so the loop variables | |||
| // would be added in the outer contexts properly. | |||
| _InitializeValues(loop_vars); | |||
| var real_vars = loop_vars; | |||
| Tensor[] enter_vars = null; | |||
| tf_with(ops.control_dependencies(null), delegate | |||
| if (loop_vars is Tensor[] real_vars) | |||
| { | |||
| enter_vars = real_vars.Select(x => _Enter(x, | |||
| _name, | |||
| is_constant: false, | |||
| parallel_iterations: _parallel_iterations, | |||
| use_input_shape: shape_invariants == null)) | |||
| .ToArray(); | |||
| foreach(var x in enter_vars) | |||
| _InitializeValues(real_vars); | |||
| Tensor[] enter_vars = null; | |||
| tf_with(ops.control_dependencies(null), delegate | |||
| { | |||
| enter_vars = real_vars.Select(x => _Enter(x, | |||
| _name, | |||
| is_constant: false, | |||
| parallel_iterations: _parallel_iterations, | |||
| use_input_shape: shape_invariants == null)) | |||
| .ToArray(); | |||
| foreach (var x in enter_vars) | |||
| { | |||
| x.graph.prevent_feeding(x); | |||
| if (_outer_context != null) | |||
| _outer_context.AddInnerOp(x.op); | |||
| } | |||
| }); | |||
| // Finds the closest enclosing non-None control pivot. | |||
| var outer_context = _outer_context; | |||
| while (outer_context != null) | |||
| { | |||
| x.graph.prevent_feeding(x); | |||
| if (_outer_context != null) | |||
| _outer_context.AddInnerOp(x.op); | |||
| } | |||
| }); | |||
| // Finds the closest enclosing non-None control pivot. | |||
| var outer_context = _outer_context; | |||
| while (outer_context != null) | |||
| { | |||
| _SetShapeInvariants(real_vars, enter_vars, shape_invariants); | |||
| // Fix the control inputs and control flow context of these enter ops. | |||
| _FixControlInputsAndContext(enter_vars); | |||
| _InitializeValues(enter_vars); | |||
| _loop_enters = enter_vars.ToList(); | |||
| var merge_vars = enter_vars | |||
| .Select(x => merge(new[] { x, x })) | |||
| .ToArray(); | |||
| _pivot_for_pred = merge_vars[0]; | |||
| // Build the graph for pred. | |||
| var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); | |||
| // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); | |||
| var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0], default(TItem))); | |||
| _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); | |||
| var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) | |||
| .ToArray(); | |||
| // Build the graph for body. | |||
| var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); | |||
| // Convert TensorArray flow variables inside the context back into | |||
| // their associated TensorArrays for calling the body. | |||
| var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||
| /*var body_result = body(packed_vars_for_body[0]); | |||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
| // Store body_result to keep track of TensorArrays returned by body | |||
| var original_body_result = new[] { body_result }; | |||
| // Convert TensorArrays returned by body into their flow variables | |||
| var result = new[] { body_result }; | |||
| var next_vars = new List<Tensor>(); | |||
| foreach (var (m, v) in zip(merge_vars, result)) | |||
| next_vars.Add(_AddNextAndBackEdge(m, v)); | |||
| // Add the exit ops. | |||
| var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); | |||
| _loop_exits = exit_vars; | |||
| // Exit the loop. | |||
| // ExitResult(exit_vars); | |||
| return (original_body_result, exit_vars.ToArray());*/ | |||
| } | |||
| _SetShapeInvariants(real_vars, enter_vars, shape_invariants); | |||
| // Fix the control inputs and control flow context of these enter ops. | |||
| _FixControlInputsAndContext(enter_vars); | |||
| _InitializeValues(enter_vars); | |||
| _loop_enters = enter_vars.ToList(); | |||
| var merge_vars = enter_vars | |||
| .Select(x => merge(new[] { x, x })) | |||
| .ToArray(); | |||
| _pivot_for_pred = merge_vars[0]; | |||
| // Build the graph for pred. | |||
| var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); | |||
| // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); | |||
| var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0])); | |||
| _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); | |||
| var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) | |||
| .ToArray(); | |||
| // Build the graph for body. | |||
| var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); | |||
| // Convert TensorArray flow variables inside the context back into | |||
| // their associated TensorArrays for calling the body. | |||
| var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||
| var body_result = body(packed_vars_for_body[0]); | |||
| var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION); | |||
| // Store body_result to keep track of TensorArrays returned by body | |||
| var original_body_result = new[] { body_result }; | |||
| // Convert TensorArrays returned by body into their flow variables | |||
| var result = new[] { body_result }; | |||
| var next_vars = new List<Tensor>(); | |||
| foreach (var (m, v) in zip(merge_vars, result)) | |||
| next_vars.Add(_AddNextAndBackEdge(m, v)); | |||
| // Add the exit ops. | |||
| var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); | |||
| _loop_exits = exit_vars; | |||
| // Exit the loop. | |||
| // ExitResult(exit_vars); | |||
| return (original_body_result, exit_vars.ToArray()); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| private void _FixControlInputsAndContext(Tensor[] enters) | |||