| @@ -37,7 +37,7 @@ namespace Tensorflow | |||
| public Operation group<T>(T[] inputs, string name = null) where T : ITensorOrOperation | |||
| => control_flow_ops.group(inputs, name: name); | |||
| public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||
| /*public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||
| TensorShape shape_invariants = null, | |||
| int parallel_iterations = 10, | |||
| bool back_prop = true, | |||
| @@ -52,7 +52,7 @@ namespace Tensorflow | |||
| swap_memory: swap_memory, | |||
| name: name, | |||
| maximum_iterations: maximum_iterations, | |||
| return_same_structure: return_same_structure); | |||
| return_same_structure: return_same_structure);*/ | |||
| public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | |||
| => ops.control_dependencies(control_inputs); | |||
| @@ -0,0 +1,25 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| internal class LoopVar<TItem> | |||
| { | |||
| public Tensor Counter { get; } | |||
| public TItem[] Items { get; } | |||
| public TItem Item { get; } | |||
| public LoopVar(Tensor counter, TItem[] items) | |||
| { | |||
| Counter = counter; | |||
| Items = items; | |||
| } | |||
| public LoopVar(Tensor counter, TItem item) | |||
| { | |||
| Counter = counter; | |||
| Item = item; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,32 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Operations | |||
| { | |||
| internal class BodyItemInRnnWhileLoop | |||
| { | |||
| /// <summary> | |||
| /// int32 scalar Tensor. | |||
| /// </summary> | |||
| public Tensor time { get; set; } | |||
| /// <summary> | |||
| /// List of `TensorArray`s that represent the output. | |||
| /// </summary> | |||
| public TensorArray[] output_ta_t { get; set; } | |||
| /// <summary> | |||
| /// nested tuple of vector tensors that represent the state. | |||
| /// </summary> | |||
| public Tensor state { get; set; } | |||
| public BodyItemInRnnWhileLoop(Tensor time, TensorArray[] output_ta_t, Tensor state) | |||
| { | |||
| this.time = time; | |||
| this.output_ta_t = output_ta_t; | |||
| this.state = state; | |||
| } | |||
| public static implicit operator (Tensor, TensorArray[], Tensor)(BodyItemInRnnWhileLoop item) | |||
| => (item.time, item.output_ta_t, item.state); | |||
| } | |||
| } | |||
| @@ -145,7 +145,7 @@ namespace Tensorflow.Operations | |||
| { | |||
| var ta = new TensorArray(dtype: dtype_, | |||
| size: time_steps, | |||
| element_shape: new[] { element_shape }, | |||
| element_shape: element_shape, | |||
| tensor_array_name: base_name + name); | |||
| return ta; | |||
| }; | |||
| @@ -178,19 +178,29 @@ namespace Tensorflow.Operations | |||
| // Make sure that we run at least 1 step, if necessary, to ensure | |||
| // the TensorArrays pick up the dynamic shape. | |||
| Tensor loop_bound; | |||
| Tensor loop_bound = null; | |||
| if (in_graph_mode) | |||
| loop_bound = math_ops.minimum( | |||
| time_steps, math_ops.maximum(1, max_sequence_length)); | |||
| /*Func<Tensor, Tensor> cond = (ctime) => | |||
| Func<BodyItemInRnnWhileLoop, Tensor> cond = (item) => | |||
| { | |||
| return null; | |||
| return time < loop_bound; | |||
| }; | |||
| control_flow_ops.while_loop( | |||
| // Take a time step of the dynamic RNN. | |||
| Func<BodyItemInRnnWhileLoop, BodyItemInRnnWhileLoop> _time_step = (item) => | |||
| { | |||
| return item; | |||
| }; | |||
| control_flow_ops.while_loop<BodyItemInRnnWhileLoop>( | |||
| cond: cond, | |||
| body = );*/ | |||
| body: _time_step, | |||
| loop_vars: new BodyItemInRnnWhileLoop(time, output_ta.ToArray(), state), | |||
| parallel_iterations: parallel_iterations, | |||
| maximum_iterations: time_steps, | |||
| swap_memory: swap_memory); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| @@ -39,7 +39,7 @@ namespace Tensorflow.Operations | |||
| public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, | |||
| string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
| bool infer_shape = true, TensorShape[] element_shape = null, | |||
| bool infer_shape = true, TensorShape element_shape = null, | |||
| bool colocate_with_first_write_call = true, string name = null) | |||
| { | |||
| _implementation = new _GraphTensorArray(dtype, | |||
| @@ -44,7 +44,7 @@ namespace Tensorflow.Operations | |||
| public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | |||
| bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
| bool infer_shape = true, TensorShape[] element_shape = null, | |||
| bool infer_shape = true, TensorShape element_shape = null, | |||
| bool colocate_with_first_write_call = true, string name = null) | |||
| { | |||
| clear_after_read = clear_after_read ?? true; | |||
| @@ -68,7 +68,7 @@ namespace Tensorflow.Operations | |||
| else | |||
| { | |||
| _infer_shape = true; | |||
| _element_shape = new List<TensorShape> { }; | |||
| _element_shape = new List<TensorShape> { element_shape }; | |||
| } | |||
| tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope => | |||
| @@ -135,7 +135,7 @@ namespace Tensorflow.Operations | |||
| var ta = new TensorArray(_dtype, | |||
| infer_shape:_infer_shape, | |||
| element_shape: _element_shape.ToArray(), | |||
| element_shape: _element_shape[0], | |||
| dynamic_size: _dynamic_size, | |||
| handle: _handle, | |||
| flow: flow_out, | |||
| @@ -485,7 +485,7 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
| public static Tensor[] _convert_flows_to_tensorarrays<T>(T[] tensors_or_tensorarrays, Tensor[] tensors_or_flows) | |||
| public static Tensor[] _convert_flows_to_tensorarrays<T>(T tensors_or_tensorarrays, Tensor[] tensors_or_flows) | |||
| { | |||
| // zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray(); | |||
| return tensors_or_flows; | |||
| @@ -591,18 +591,18 @@ namespace Tensorflow | |||
| /// <param name="body"></param> | |||
| /// <param name="loop_vars"></param> | |||
| /// <param name="i"></param> | |||
| public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars, | |||
| public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TItem> body, TItem loop_vars, | |||
| TensorShape shape_invariants = null, | |||
| int parallel_iterations = 10, | |||
| bool back_prop = true, | |||
| bool swap_memory = false, | |||
| string name = null, | |||
| int? maximum_iterations = null, | |||
| Tensor maximum_iterations = null, | |||
| bool return_same_structure = false) | |||
| { | |||
| tf_with(ops.name_scope(name, "while", loop_vars), scope => | |||
| { | |||
| if (loop_vars == null || loop_vars.Length == 0) | |||
| if (loop_vars == null) | |||
| throw new ValueError("No loop variables provided"); | |||
| if (cond == null) | |||
| throw new ValueError("cond must be callable."); | |||
| @@ -611,6 +611,28 @@ namespace Tensorflow | |||
| if (parallel_iterations < 1) | |||
| throw new ValueError("parallel_iterations must be a positive integer."); | |||
| var try_to_pack = loop_vars is Tensor && !return_same_structure; | |||
| var counter = constant_op.constant(0, dtype: maximum_iterations.dtype, name: "iteration_counter"); | |||
| var orig_cond = cond; | |||
| var orig_body = body; | |||
| LoopVar<TItem> loop_vars_1 = null; | |||
| Func<Tensor, TItem, LoopVar<TItem>> body_buildloop = null; | |||
| Func<Tensor, TItem, Tensor> cond_buildloop = null; | |||
| if (try_to_pack) | |||
| { | |||
| } | |||
| else | |||
| { | |||
| loop_vars_1 = new LoopVar<TItem>(counter, loop_vars); | |||
| cond_buildloop = (i, lv) => | |||
| math_ops.logical_and(i < maximum_iterations, orig_cond(lv)); | |||
| body_buildloop = (i, lv) => new LoopVar<TItem>(i + 1, orig_body(lv)); | |||
| } | |||
| try_to_pack = false; | |||
| var loop_context = new WhileContext( | |||
| maximum_iterations: maximum_iterations, | |||
| parallel_iterations: parallel_iterations, | |||
| @@ -620,7 +642,7 @@ namespace Tensorflow | |||
| if (loop_context.outer_context == null) | |||
| ops.add_to_collection(tf.GraphKeys.WHILE_CONTEXT, loop_context); | |||
| var results = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, | |||
| var results = loop_context.BuildLoop(cond_buildloop, body_buildloop, loop_vars, shape_invariants, | |||
| return_same_structure); | |||
| if (maximum_iterations != null) | |||
| @@ -28,12 +28,9 @@ namespace Tensorflow | |||
| } | |||
| public static (Tensor, Tensor) tensor_array_v3<T>(T size, TF_DataType dtype = TF_DataType.DtInvalid, | |||
| TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||
| bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null) | |||
| TensorShape element_shape = null, bool dynamic_size = false, bool clear_after_read = true, | |||
| bool identical_element_shapes = false, string tensor_array_name = "", string name = null) | |||
| { | |||
| if (tensor_array_name == null) | |||
| tensor_array_name = string.Empty; | |||
| var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new | |||
| { | |||
| size, | |||
| @@ -223,7 +223,6 @@ namespace Tensorflow.Util | |||
| private static void _flatten_recursive<T>(T obj, List<T> list) | |||
| { | |||
| switch(obj) | |||
| { | |||
| case IDictionary dict: | |||
| @@ -93,14 +93,14 @@ namespace Tensorflow | |||
| return new Session().as_default(); | |||
| } | |||
| public Session Session(Graph graph, SessionOptions opts = null) | |||
| public Session Session(Graph graph, ConfigProto config = null) | |||
| { | |||
| return new Session(graph, opts: opts).as_default(); | |||
| return new Session(graph, config: config).as_default(); | |||
| } | |||
| public Session Session(SessionOptions opts) | |||
| public Session Session(ConfigProto config) | |||
| { | |||
| return new Session(null, opts).as_default(); | |||
| return new Session(null, config).as_default(); | |||
| } | |||
| public void __init__() | |||
| @@ -25,9 +25,8 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| lock (Locks.ProcessWide) | |||
| { | |||
| var opts = new SessionOptions(); | |||
| opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4}); | |||
| session_ = new Session(graph, opts, s); | |||
| var config = new ConfigProto {InterOpParallelismThreads = 4}; | |||
| session_ = new Session(graph, config, s); | |||
| } | |||
| } | |||
| @@ -18,10 +18,10 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| var i = constant_op.constant(0, name: "i"); | |||
| var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); | |||
| var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); | |||
| var r = control_flow_ops.while_loop(c, b, new[] { i }); | |||
| var r = control_flow_ops.while_loop(c, b, i); | |||
| } | |||
| private void _testWhileContextHelper(int? maximum_iterations = null) | |||
| private void _testWhileContextHelper(int maximum_iterations) | |||
| { | |||
| // TODO: implement missing code dependencies | |||
| using (var sess = this.cached_session()) | |||
| @@ -30,7 +30,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | |||
| var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | |||
| control_flow_ops.while_loop( | |||
| c, b, new[] { i }, maximum_iterations: maximum_iterations); | |||
| c, b, i , maximum_iterations: tf.constant(maximum_iterations)); | |||
| foreach (Operation op in sess.graph.get_operations()) | |||
| { | |||
| var control_flow_context = op._get_control_flow_context(); | |||
| @@ -42,13 +42,6 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| } | |||
| } | |||
| [Ignore("TODO")] | |||
| [TestMethod] | |||
| public void testWhileContext() | |||
| { | |||
| _testWhileContextHelper(); | |||
| } | |||
| [Ignore("TODO")] | |||
| [TestMethod] | |||
| public void testWhileContextWithMaximumIterations() | |||