From 593ce2b6c3a1481b0cc148086035facc5708a57d Mon Sep 17 00:00:00 2001 From: Brendan Mulcahy Date: Fri, 29 Nov 2019 16:35:11 -0500 Subject: [PATCH] Adjust types to get while_loop working --- .../Operations/functional_ops.cs | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs index f392d766..68e56fb9 100644 --- a/src/TensorFlowNET.Core/Operations/functional_ops.cs +++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs @@ -39,11 +39,11 @@ namespace Tensorflow bool input_is_sequence = nest.is_sequence(elems); List input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List {x}; - object input_pack(List x) => input_is_sequence ? nest.pack_sequence_as(elems, x) : x[0]; + Tensor input_pack(List x) => input_is_sequence ? (Tensor)nest.pack_sequence_as(elems, x) : x[0]; bool output_is_sequence; Func> output_flatten; - Func, object> output_pack; + Func, Tensor> output_pack; if (initializer == null) { output_is_sequence = input_is_sequence; @@ -54,7 +54,7 @@ namespace Tensorflow { output_is_sequence = nest.is_sequence(initializer); output_flatten = (x) => output_is_sequence ? nest.flatten(x) : new List {x}; - output_pack = (x) => output_is_sequence ? nest.pack_sequence_as(initializer, x) : x[0]; + output_pack = (x) => output_is_sequence ? (Tensor)nest.pack_sequence_as(initializer, x) : x[0]; } var elems_flat = input_flatten(elems); @@ -130,8 +130,11 @@ namespace Tensorflow } } - (int, List, List) compute(int _i, List a_flat_, List tas) + (int, List, List) compute(ValueTuple, List> tuple) { + + (int _i, List a_flat_, List tas) = tuple; + var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList()); var packed_a = output_pack(a_flat_); var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal? @@ -147,19 +150,19 @@ namespace Tensorflow } int initial_i; - Func condition; + Func<(int, List, List), Tensor> condition; if (reverse) { initial_i = n - 1 - i; - condition = x => tf.constant(x >= 0); + condition = x => tf.constant(x.Item1 >= 0); } else { initial_i = i; - condition = x => tf.convert_to_tensor(x < n); + condition = x => tf.constant(x.Item1 < n); } - List r_a = + (_, _, List r_a) = control_flow_ops.while_loop( condition, compute, @@ -167,7 +170,7 @@ namespace Tensorflow parallel_iterations: parallel_iterations, back_prop: back_prop, swap_memory: swap_memory, - maximum_iterations: n); + maximum_iterations: tf.constant(n)); var results_flat = r_a.Select(r => r.stack()).ToList();