| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | |||||
| using System.Text; | using System.Text; | ||||
| using NumSharp; | using NumSharp; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| @@ -21,7 +22,12 @@ namespace Tensorflow.Hub | |||||
| images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | ||||
| images.astype(dataType); | images.astype(dataType); | ||||
| // for debug np.multiply performance | |||||
| var sw = new Stopwatch(); | |||||
| sw.Start(); | |||||
| images = np.multiply(images, 1.0f / 255.0f); | images = np.multiply(images, 1.0f / 255.0f); | ||||
| sw.Stop(); | |||||
| Console.WriteLine($"{sw.ElapsedMilliseconds}ms"); | |||||
| Data = images; | Data = images; | ||||
| labels.astype(dataType); | labels.astype(dataType); | ||||
| @@ -14,10 +14,29 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public static partial class tf | public static partial class tf | ||||
| { | { | ||||
| public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||||
| TensorShape shape_invariants = null, | |||||
| int parallel_iterations = 10, | |||||
| bool back_prop = true, | |||||
| bool swap_memory = false, | |||||
| string name = null, | |||||
| int? maximum_iterations = null, | |||||
| bool return_same_structure = false) | |||||
| => control_flow_ops.while_loop(cond, body, loop_vars, | |||||
| shape_invariants: shape_invariants, | |||||
| parallel_iterations: parallel_iterations, | |||||
| back_prop: back_prop, | |||||
| swap_memory: swap_memory, | |||||
| name: name, | |||||
| maximum_iterations: maximum_iterations, | |||||
| return_same_structure: return_same_structure); | |||||
| public static _ControlDependenciesController control_dependencies(Operation[] control_inputs) | public static _ControlDependenciesController control_dependencies(Operation[] control_inputs) | ||||
| => ops.control_dependencies(control_inputs); | => ops.control_dependencies(control_inputs); | ||||
| } | } | ||||
| @@ -39,8 +39,8 @@ namespace Tensorflow | |||||
| public static Tensor asin(Tensor x, string name = null) | public static Tensor asin(Tensor x, string name = null) | ||||
| => gen_math_ops.asin(x, name); | => gen_math_ops.asin(x, name); | ||||
| public static Tensor add<Tx, Ty>(Tx a, Ty b) | |||||
| => gen_math_ops.add(a, b); | |||||
| public static Tensor add<Tx, Ty>(Tx a, Ty b, string name = null) | |||||
| => gen_math_ops.add(a, b, name: name); | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes atan of x element-wise. | /// Computes atan of x element-wise. | ||||
| @@ -33,7 +33,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="input_ops">The data input ops for an op to be created.</param> | /// <param name="input_ops">The data input ops for an op to be created.</param> | ||||
| /// <returns>A list of control inputs for the op to be created.</returns> | /// <returns>A list of control inputs for the op to be created.</returns> | ||||
| private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) | |||||
| public ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) | |||||
| { | { | ||||
| var ret = new List<ITensorOrOperation>(); | var ret = new List<ITensorOrOperation>(); | ||||
| @@ -53,6 +53,11 @@ namespace Tensorflow.Operations | |||||
| protected Stack<ControlFlowContext> _context_stack; | protected Stack<ControlFlowContext> _context_stack; | ||||
| protected ControlFlowContext _outer_context; | protected ControlFlowContext _outer_context; | ||||
| /// <summary> | |||||
| /// The keys are the names of tensors referenced by but external to this | |||||
| /// context. Each value is the Tensor that should be used by this context to | |||||
| /// access the key value (e.g. a switch output guarding a cond input value). | |||||
| /// </summary> | |||||
| protected Dictionary<string, ITensorOrOperation> _external_values; | protected Dictionary<string, ITensorOrOperation> _external_values; | ||||
| public ControlFlowContext() | public ControlFlowContext() | ||||
| @@ -68,6 +73,12 @@ namespace Tensorflow.Operations | |||||
| _outer_context = ops.get_default_graph()._get_control_flow_context(); | _outer_context = ops.get_default_graph()._get_control_flow_context(); | ||||
| if (values_def != null) | if (values_def != null) | ||||
| _init_values_from_proto(values_def, import_scope: import_scope); | _init_values_from_proto(values_def, import_scope: import_scope); | ||||
| else | |||||
| { | |||||
| _values = new HashSet<string>(); | |||||
| _external_values = new Dictionary<string, ITensorOrOperation>(); | |||||
| } | |||||
| } | } | ||||
| public void __enter__() | public void __enter__() | ||||
| @@ -114,6 +125,27 @@ namespace Tensorflow.Operations | |||||
| graph._set_control_flow_context(this); | graph._set_control_flow_context(this); | ||||
| } | } | ||||
| protected virtual Tensor _Enter(Tensor data, string frame_name, | |||||
| bool is_constant = false, | |||||
| int parallel_iterations = 10, | |||||
| bool use_ref = true, | |||||
| bool use_input_shape = true, | |||||
| string name = null) | |||||
| { | |||||
| Tensor result; | |||||
| data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); | |||||
| if (data.dtype.is_ref_dtype() && use_ref) | |||||
| throw new NotImplementedException("_Enter"); | |||||
| else | |||||
| result = gen_control_flow_ops.enter( | |||||
| data, frame_name, is_constant, parallel_iterations, name: name); | |||||
| if (use_input_shape) | |||||
| result.SetShape(data.TensorShape); | |||||
| return result; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Exit this control flow context. | /// Exit this control flow context. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -184,6 +216,10 @@ namespace Tensorflow.Operations | |||||
| return true; | return true; | ||||
| } | } | ||||
| protected virtual bool _IsInOuterContext(Operation op) | |||||
| { | |||||
| throw new NotImplementedException("_IsInOuterContext"); | |||||
| } | |||||
| protected virtual void _RemoveExternalControlEdges(Operation op) | protected virtual void _RemoveExternalControlEdges(Operation op) | ||||
| { | { | ||||
| @@ -15,8 +15,12 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Operations.ControlFlows; | using Tensorflow.Operations.ControlFlows; | ||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Python; | using static Tensorflow.Python; | ||||
| using static Tensorflow.control_flow_ops; | |||||
| namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
| { | { | ||||
| @@ -32,10 +36,14 @@ namespace Tensorflow.Operations | |||||
| bool _swap_memory; | bool _swap_memory; | ||||
| Tensor _pivot_for_pred; | Tensor _pivot_for_pred; | ||||
| Tensor _pivot_for_body; | Tensor _pivot_for_body; | ||||
| Tensor[] _loop_exits; | |||||
| Tensor[] _loop_enters; | |||||
| List<Tensor> _loop_exits; | |||||
| List<Tensor> _loop_enters; | |||||
| Graph _graph; | |||||
| public override GradLoopState grad_state => _grad_state; | |||||
| public override bool back_prop => _back_prop; | |||||
| public WhileContext(int parallel_iterations = 10, | |||||
| public WhileContext(int? maximum_iterations = null, | |||||
| int parallel_iterations = 10, | |||||
| bool back_prop = true, | bool back_prop = true, | ||||
| bool swap_memory = false, | bool swap_memory = false, | ||||
| string name = "while_context", | string name = "while_context", | ||||
| @@ -49,12 +57,27 @@ namespace Tensorflow.Operations | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| __init__(); | |||||
| _init_from_args(maximum_iterations, parallel_iterations, back_prop, swap_memory, name); | |||||
| } | } | ||||
| _grad_state = grad_state; | _grad_state = grad_state; | ||||
| } | } | ||||
| private void _init_from_args(int? maximum_iterations, | |||||
| int parallel_iterations, | |||||
| bool back_prop, | |||||
| bool swap_memory, | |||||
| string name) | |||||
| { | |||||
| _name = ops.get_default_graph().unique_name(name); | |||||
| _back_prop = back_prop; | |||||
| _swap_memory = swap_memory; | |||||
| _loop_exits = new List<Tensor>(); | |||||
| _loop_enters = new List<Tensor>(); | |||||
| _graph = ops.get_default_graph(); | |||||
| } | |||||
| private void _init_from_proto(WhileContextDef context_def, string import_scope = null) | private void _init_from_proto(WhileContextDef context_def, string import_scope = null) | ||||
| { | { | ||||
| var g = ops.get_default_graph(); | var g = ops.get_default_graph(); | ||||
| @@ -70,26 +93,156 @@ namespace Tensorflow.Operations | |||||
| // The boolean tensor for loop termination condition. | // The boolean tensor for loop termination condition. | ||||
| _pivot = g.as_graph_element(ops.prepend_name_scope(context_def.PivotName, import_scope)) as Tensor; | _pivot = g.as_graph_element(ops.prepend_name_scope(context_def.PivotName, import_scope)) as Tensor; | ||||
| // The list of exit tensors for loop variables. | // The list of exit tensors for loop variables. | ||||
| _loop_exits = new Tensor[context_def.LoopExitNames.Count]; | |||||
| _loop_exits = new List<Tensor>(); | |||||
| foreach (var (i, exit_name) in enumerate(context_def.LoopExitNames)) | foreach (var (i, exit_name) in enumerate(context_def.LoopExitNames)) | ||||
| _loop_exits[i] = g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor; | |||||
| _loop_exits.Add(g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor); | |||||
| // The list of enter tensors for loop variables. | // The list of enter tensors for loop variables. | ||||
| _loop_enters = new Tensor[context_def.LoopEnterNames.Count]; | |||||
| _loop_enters = new List<Tensor>(); | |||||
| foreach (var (i, enter_name) in enumerate(context_def.LoopEnterNames)) | foreach (var (i, enter_name) in enumerate(context_def.LoopEnterNames)) | ||||
| _loop_enters[i] = g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor; | |||||
| _loop_enters.Add(g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor); | |||||
| __init__(values_def: context_def.ValuesDef, import_scope: import_scope); | __init__(values_def: context_def.ValuesDef, import_scope: import_scope); | ||||
| } | } | ||||
| public override WhileContext GetWhileContext() | |||||
| /// <summary> | |||||
| /// Add the loop termination condition and body to the graph. | |||||
| /// </summary> | |||||
| public Tensor[] BuildLoop(Func<Tensor, Tensor> pred, | |||||
| Func<Tensor, Tensor> body, | |||||
| Tensor[] loop_vars, | |||||
| TensorShape shape_invariants, | |||||
| bool return_same_structure) | |||||
| { | { | ||||
| return this; | |||||
| // Keep original_loop_vars to identify which are TensorArrays | |||||
| var original_loop_vars = loop_vars; | |||||
| // Convert TensorArrays to their flow variables | |||||
| Enter(); | |||||
| var(original_body_result, exit_vars) = _BuildLoop( | |||||
| pred, body, original_loop_vars, loop_vars, shape_invariants); | |||||
| Exit(); | |||||
| var flat_result = original_body_result; | |||||
| var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars); | |||||
| var packed_exit_vars = nest.pack_sequence_as( | |||||
| structure: original_body_result, | |||||
| flat_sequence: exit_vars_with_tensor_arrays); | |||||
| return packed_exit_vars as Tensor[]; | |||||
| } | } | ||||
| private (Tensor[], Tensor[]) _BuildLoop(Func<Tensor, Tensor> pred, | |||||
| Func<Tensor, Tensor> body, | |||||
| Tensor[] original_loop_vars, | |||||
| Tensor[] loop_vars, | |||||
| TensorShape shape_invariants) | |||||
| { | |||||
| var flat_loop_vars = original_loop_vars; | |||||
| public override GradLoopState grad_state => _grad_state; | |||||
| // Let the context know the loop variables so the loop variables | |||||
| // would be added in the outer contexts properly. | |||||
| _InitializeValues(loop_vars); | |||||
| var real_vars = loop_vars; | |||||
| Tensor[] enter_vars = null; | |||||
| tf_with(ops.control_dependencies(null), delegate | |||||
| { | |||||
| enter_vars = real_vars.Select(x => _Enter(x, | |||||
| _name, | |||||
| is_constant: false, | |||||
| parallel_iterations: _parallel_iterations, | |||||
| use_input_shape: shape_invariants == null)) | |||||
| .ToArray(); | |||||
| public override bool back_prop => _back_prop; | |||||
| foreach(var x in enter_vars) | |||||
| { | |||||
| x.graph.prevent_feeding(x); | |||||
| if (_outer_context != null) | |||||
| _outer_context.AddInnerOp(x.op); | |||||
| } | |||||
| }); | |||||
| // Finds the closest enclosing non-None control pivot. | |||||
| var outer_context = _outer_context; | |||||
| while (outer_context != null) | |||||
| { | |||||
| } | |||||
| _SetShapeInvariants(real_vars, enter_vars, shape_invariants); | |||||
| // Fix the control inputs and control flow context of these enter ops. | |||||
| _FixControlInputsAndContext(enter_vars); | |||||
| _InitializeValues(enter_vars); | |||||
| _loop_enters = enter_vars.ToList(); | |||||
| var merge_vars = enter_vars | |||||
| .Select(x => merge(new[] { x, x })) | |||||
| .ToArray(); | |||||
| _pivot_for_pred = merge_vars[0]; | |||||
| // Build the graph for pred. | |||||
| var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars); | |||||
| // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays); | |||||
| var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0])); | |||||
| _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond"); | |||||
| var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot)) | |||||
| .ToArray(); | |||||
| // Build the graph for body. | |||||
| var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray(); | |||||
| // Convert TensorArray flow variables inside the context back into | |||||
| // their associated TensorArrays for calling the body. | |||||
| var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body); | |||||
| var body_result = body(packed_vars_for_body[0]); | |||||
| var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
| // Store body_result to keep track of TensorArrays returned by body | |||||
| var original_body_result = new[] { body_result }; | |||||
| // Convert TensorArrays returned by body into their flow variables | |||||
| var result = new[] { body_result }; | |||||
| var next_vars = new List<Tensor>(); | |||||
| foreach (var (m, v) in zip(merge_vars, result)) | |||||
| next_vars.Add(_AddNextAndBackEdge(m, v)); | |||||
| // Add the exit ops. | |||||
| var exit_vars = switch_vars.Select(x => exit(x[0])).ToList(); | |||||
| _loop_exits = exit_vars; | |||||
| // Exit the loop. | |||||
| // ExitResult(exit_vars); | |||||
| return (original_body_result, exit_vars.ToArray()); | |||||
| } | |||||
| private void _FixControlInputsAndContext(Tensor[] enters) | |||||
| { | |||||
| var graph = ops.get_default_graph(); | |||||
| foreach(var e in enters) | |||||
| { | |||||
| var inp_op = e.op.inputs[0].op; | |||||
| var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op }); | |||||
| // op for op in control_inputs if self._IsInOuterContext(op) | |||||
| var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op)) | |||||
| .Select(x => x.op) | |||||
| .ToArray(); | |||||
| e.op._set_control_flow_context(this); | |||||
| e.op._add_control_inputs(outer_control_inputs); | |||||
| graph._record_op_seen_by_control_dependencies(e.op); | |||||
| } | |||||
| } | |||||
| private void _InitializeValues(Tensor[] values) | |||||
| { | |||||
| _values = new HashSet<string>(); | |||||
| foreach(var x in values) | |||||
| _values.Add(x.name); | |||||
| } | |||||
| public override WhileContext GetWhileContext() | |||||
| { | |||||
| return this; | |||||
| } | |||||
| public WhileContext from_proto(WhileContextDef proto, string import_scope) | public WhileContext from_proto(WhileContextDef proto, string import_scope) | ||||
| { | { | ||||
| @@ -141,30 +141,57 @@ namespace Tensorflow.Operations | |||||
| string base_name = null; | string base_name = null; | ||||
| tf_with(ops.name_scope("dynamic_rnn"), scope => base_name = scope); | tf_with(ops.name_scope("dynamic_rnn"), scope => base_name = scope); | ||||
| Func<string, TensorShape, TF_DataType, Tensor> _create_ta = (name, element_shape, dtype_) => | |||||
| Func<string, TensorShape, TF_DataType, TensorArray> _create_ta = (name, element_shape, dtype_) => | |||||
| { | { | ||||
| new TensorArray(dtype: dtype_, | |||||
| var ta = new TensorArray(dtype: dtype_, | |||||
| size: time_steps, | size: time_steps, | ||||
| element_shape: element_shape, | element_shape: element_shape, | ||||
| tensor_array_name: base_name + name); | tensor_array_name: base_name + name); | ||||
| throw new NotImplementedException(""); | |||||
| return ta; | |||||
| }; | }; | ||||
| bool in_graph_mode = true; | bool in_graph_mode = true; | ||||
| var output_ta = new List<TensorArray>(); | |||||
| var input_ta = new List<TensorArray>(); | |||||
| if (in_graph_mode) | if (in_graph_mode) | ||||
| { | { | ||||
| foreach(var (i, out_size) in enumerate(flat_output_size)) | |||||
| foreach (var (i, out_size) in enumerate(flat_output_size)) | |||||
| { | { | ||||
| _create_ta($"output_{i}", | |||||
| output_ta.Add(_create_ta($"output_{i}", | |||||
| new TensorShape(const_batch_size).concatenate( | new TensorShape(const_batch_size).concatenate( | ||||
| _maybe_tensor_shape_from_tensor(out_size)), | _maybe_tensor_shape_from_tensor(out_size)), | ||||
| _infer_state_dtype(dtype, state)); | |||||
| _infer_state_dtype(dtype, state))); | |||||
| } | |||||
| foreach (var (i, flat_input_i) in enumerate(flat_input)) | |||||
| { | |||||
| input_ta.Add(_create_ta($"input_{i}", | |||||
| new TensorShape(flat_input_i.dims.Skip(1).ToArray()), | |||||
| flat_input_i.dtype)); | |||||
| } | |||||
| for (int i = 0; i < input_ta.Count; i++) | |||||
| { | |||||
| var (ta, input_) = (input_ta[0], flat_input[0]); | |||||
| } | } | ||||
| } | } | ||||
| // Make sure that we run at least 1 step, if necessary, to ensure | |||||
| // the TensorArrays pick up the dynamic shape. | |||||
| Tensor loop_bound; | |||||
| if (in_graph_mode) | |||||
| loop_bound = math_ops.minimum( | |||||
| time_steps, math_ops.maximum(1, max_sequence_length)); | |||||
| /*Func<Tensor, Tensor> cond = (ctime) => | |||||
| { | |||||
| return null; | |||||
| }; | |||||
| control_flow_ops.while_loop( | |||||
| cond: cond, | |||||
| body = );*/ | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| @@ -26,6 +26,44 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class control_flow_ops | public class control_flow_ops | ||||
| { | { | ||||
| public static Tensor _AddNextAndBackEdge(Tensor m, Tensor v, bool enforce_shape_invariant = true) | |||||
| { | |||||
| v = ops.convert_to_tensor(v); | |||||
| v = _NextIteration(v); | |||||
| if (enforce_shape_invariant) | |||||
| _EnforceShapeInvariant(m, v); | |||||
| m.op._update_input(1, v); | |||||
| return v; | |||||
| } | |||||
| /// <summary> | |||||
| /// Check if the shapes of the loops variables are invariants. | |||||
| /// </summary> | |||||
| /// <param name="merge_var"></param> | |||||
| /// <param name="next_var"></param> | |||||
| public static void _EnforceShapeInvariant(Tensor merge_var, Tensor next_var) | |||||
| { | |||||
| } | |||||
| public static Tensor exit(Tensor data, string name = null) | |||||
| { | |||||
| data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); | |||||
| if (data.dtype.is_ref_dtype()) | |||||
| return gen_control_flow_ops.ref_exit(data, name: name); | |||||
| else | |||||
| return gen_control_flow_ops._exit(data, name: name); | |||||
| } | |||||
| public static Tensor _NextIteration(Tensor data, string name = null) | |||||
| { | |||||
| data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true); | |||||
| if (data.dtype.is_ref_dtype()) | |||||
| return gen_control_flow_ops.ref_next_iteration(data, name: name); | |||||
| else | |||||
| return gen_control_flow_ops.next_iteration(data, name: name); | |||||
| } | |||||
| public static Operation Assert(Tensor condition, object[] data, int? summarize = null, string name = null) | public static Operation Assert(Tensor condition, object[] data, int? summarize = null, string name = null) | ||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "Assert", new { condition, data }), scope => | return tf_with(ops.name_scope(name, "Assert", new { condition, data }), scope => | ||||
| @@ -213,6 +251,14 @@ namespace Tensorflow | |||||
| return gen_array_ops.identity(data, name: name); | return gen_array_ops.identity(data, name: name); | ||||
| } | } | ||||
| public static void _SetShapeInvariants(Tensor[] input_vars, Tensor[] enter_vars, TensorShape shapes = null) | |||||
| { | |||||
| if (shapes == null) | |||||
| return; | |||||
| throw new NotImplementedException("_SetShapeInvariants"); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Forwards `data` to an output determined by `pred`. | /// Forwards `data` to an output determined by `pred`. | ||||
| /// If `pred` is false, the `data` input is forwarded to the first output. | /// If `pred` is false, the `data` input is forwarded to the first output. | ||||
| @@ -516,10 +562,52 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("ZerosLikeOutsideLoop"); | throw new NotImplementedException("ZerosLikeOutsideLoop"); | ||||
| } | } | ||||
| // TODO | |||||
| public static void while_loop(Func<Tensor, Tensor> func, Func<Tensor, Tensor> func1, Tensor[] tensors, int? i) | |||||
| /// <summary> | |||||
| /// Repeat `body` while the condition `cond` is true. | |||||
| /// </summary> | |||||
| /// <param name="cond"></param> | |||||
| /// <param name="body"></param> | |||||
| /// <param name="loop_vars"></param> | |||||
| /// <param name="i"></param> | |||||
| public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||||
| TensorShape shape_invariants = null, | |||||
| int parallel_iterations = 10, | |||||
| bool back_prop = true, | |||||
| bool swap_memory = false, | |||||
| string name = null, | |||||
| int? maximum_iterations = null, | |||||
| bool return_same_structure = false) | |||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| tf_with(ops.name_scope(name, "while", loop_vars), scope => | |||||
| { | |||||
| if (loop_vars == null || loop_vars.Length == 0) | |||||
| throw new ValueError("No loop variables provided"); | |||||
| if (cond == null) | |||||
| throw new ValueError("cond must be callable."); | |||||
| if (body == null) | |||||
| throw new ValueError("body must be callable."); | |||||
| if (parallel_iterations < 1) | |||||
| throw new ValueError("parallel_iterations must be a positive integer."); | |||||
| var loop_context = new WhileContext( | |||||
| maximum_iterations: maximum_iterations, | |||||
| parallel_iterations: parallel_iterations, | |||||
| back_prop: back_prop, | |||||
| swap_memory: swap_memory); | |||||
| if (loop_context.outer_context == null) | |||||
| ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context); | |||||
| var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | |||||
| return_same_structure); | |||||
| if (maximum_iterations != null) | |||||
| return results[1]; | |||||
| else | |||||
| return results[0]; | |||||
| }); | |||||
| throw new NotImplementedException("while_loop"); | |||||
| } | } | ||||
| } | } | ||||
| @@ -20,6 +20,93 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | ||||
| /// <summary> | |||||
| /// Creates or finds a child frame, and makes `data` available to the child frame. | |||||
| /// </summary> | |||||
| /// <param name="data"></param> | |||||
| /// <param name="frame_name"></param> | |||||
| /// <param name="is_constant"></param> | |||||
| /// <param name="parallel_iterations"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor enter(Tensor data, string frame_name = "frame_name", bool is_constant = false, int parallel_iterations = 10, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Enter", name, new | |||||
| { | |||||
| data, | |||||
| frame_name, | |||||
| is_constant, | |||||
| parallel_iterations | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| /// <summary> | |||||
| /// Forwards the input to the output. | |||||
| /// </summary> | |||||
| /// <param name="input"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor loop_cond(Tensor input, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("LoopCond", name, new { input }); | |||||
| return _op.output; | |||||
| } | |||||
| /// <summary> | |||||
| /// Makes its input available to the next iteration. | |||||
| /// </summary> | |||||
| /// <param name="data"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor ref_next_iteration(Tensor data, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("RefNextIteration", name, new { data }); | |||||
| return _op; | |||||
| } | |||||
| /// <summary> | |||||
| /// Makes its input available to the next iteration. | |||||
| /// </summary> | |||||
| /// <param name="data"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor next_iteration(Tensor data, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("NextIteration", name, new { data }); | |||||
| return _op; | |||||
| } | |||||
| /// <summary> | |||||
| /// Exits the current frame to its parent frame. | |||||
| /// </summary> | |||||
| /// <param name="data"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor ref_exit(Tensor data, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("RefExit", name, new { data }); | |||||
| return _op; | |||||
| } | |||||
| /// <summary> | |||||
| /// Exits the current frame to its parent frame. | |||||
| /// </summary> | |||||
| /// <param name="data"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public static Tensor _exit(Tensor data, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Exit", name, new { data }); | |||||
| return _op; | |||||
| } | |||||
| public static Operation no_op(string name = null) | public static Operation no_op(string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("NoOp", name, null); | var _op = _op_def_lib._apply_op_helper("NoOp", name, null); | ||||
| @@ -516,6 +516,9 @@ namespace Tensorflow | |||||
| }); | }); | ||||
| } | } | ||||
| public static Tensor minimum<Tx, Ty>(Tx x, Ty y, string name = null) | |||||
| => gen_math_ops.minimum(x, y, name: name); | |||||
| public static Tensor maximum<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor maximum<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| => gen_math_ops.maximum(x, y, name: name); | => gen_math_ops.maximum(x, y, name: name); | ||||
| @@ -416,5 +416,6 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public int tensor_int_val { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -8,6 +8,18 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
| [TestClass] | [TestClass] | ||||
| public class WhileContextTestCase : PythonTest | public class WhileContextTestCase : PythonTest | ||||
| { | { | ||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/while_loop | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void SimpleWhileLoop() | |||||
| { | |||||
| var i = constant_op.constant(0, name: "i"); | |||||
| var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); | |||||
| var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); | |||||
| var r = control_flow_ops.while_loop(c, b, new[] { i }); | |||||
| } | |||||
| private void _testWhileContextHelper(int? maximum_iterations = null) | private void _testWhileContextHelper(int? maximum_iterations = null) | ||||
| { | { | ||||
| // TODO: implement missing code dependencies | // TODO: implement missing code dependencies | ||||
| @@ -17,7 +29,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
| var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | ||||
| var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | ||||
| control_flow_ops.while_loop( | control_flow_ops.while_loop( | ||||
| c, b, new[] { i }, maximum_iterations); | |||||
| c, b, new[] { i }, maximum_iterations: maximum_iterations); | |||||
| foreach (Operation op in sess.graph.get_operations()) | foreach (Operation op in sess.graph.get_operations()) | ||||
| { | { | ||||
| var control_flow_context = op._get_control_flow_context(); | var control_flow_context = op._get_control_flow_context(); | ||||