| @@ -39,11 +39,11 @@ namespace Tensorflow | |||||
| bool input_is_sequence = nest.is_sequence(elems); | bool input_is_sequence = nest.is_sequence(elems); | ||||
| List<Tensor> input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List<Tensor> {x}; | List<Tensor> input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List<Tensor> {x}; | ||||
| object input_pack(List<Tensor> x) => input_is_sequence ? nest.pack_sequence_as(elems, x) : x[0]; | |||||
| Tensor input_pack(List<Tensor> x) => input_is_sequence ? (Tensor)nest.pack_sequence_as(elems, x) : x[0]; | |||||
| bool output_is_sequence; | bool output_is_sequence; | ||||
| Func<Tensor, List<Tensor>> output_flatten; | Func<Tensor, List<Tensor>> output_flatten; | ||||
| Func<List<Tensor>, object> output_pack; | |||||
| Func<List<Tensor>, Tensor> output_pack; | |||||
| if (initializer == null) | if (initializer == null) | ||||
| { | { | ||||
| output_is_sequence = input_is_sequence; | output_is_sequence = input_is_sequence; | ||||
| @@ -54,7 +54,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| output_is_sequence = nest.is_sequence(initializer); | output_is_sequence = nest.is_sequence(initializer); | ||||
| output_flatten = (x) => output_is_sequence ? nest.flatten(x) : new List<Tensor> {x}; | output_flatten = (x) => output_is_sequence ? nest.flatten(x) : new List<Tensor> {x}; | ||||
| output_pack = (x) => output_is_sequence ? nest.pack_sequence_as(initializer, x) : x[0]; | |||||
| output_pack = (x) => output_is_sequence ? (Tensor)nest.pack_sequence_as(initializer, x) : x[0]; | |||||
| } | } | ||||
| var elems_flat = input_flatten(elems); | var elems_flat = input_flatten(elems); | ||||
| @@ -130,8 +130,11 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| (int, List<Tensor>, List<TensorArray>) compute(int _i, List<Tensor> a_flat_, List<TensorArray> tas) | |||||
| (int, List<Tensor>, List<TensorArray>) compute(ValueTuple<int, List<Tensor>, List<TensorArray>> tuple) | |||||
| { | { | ||||
| (int _i, List<Tensor> a_flat_, List<TensorArray> tas) = tuple; | |||||
| var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList()); | var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList()); | ||||
| var packed_a = output_pack(a_flat_); | var packed_a = output_pack(a_flat_); | ||||
| var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal? | var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal? | ||||
| @@ -147,19 +150,19 @@ namespace Tensorflow | |||||
| } | } | ||||
| int initial_i; | int initial_i; | ||||
| Func<int, Tensor> condition; | |||||
| Func<(int, List<Tensor>, List<TensorArray>), Tensor> condition; | |||||
| if (reverse) | if (reverse) | ||||
| { | { | ||||
| initial_i = n - 1 - i; | initial_i = n - 1 - i; | ||||
| condition = x => tf.constant(x >= 0); | |||||
| condition = x => tf.constant(x.Item1 >= 0); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| initial_i = i; | initial_i = i; | ||||
| condition = x => tf.convert_to_tensor(x < n); | |||||
| condition = x => tf.constant(x.Item1 < n); | |||||
| } | } | ||||
| List<TensorArray> r_a = | |||||
| (_, _, List<TensorArray> r_a) = | |||||
| control_flow_ops.while_loop( | control_flow_ops.while_loop( | ||||
| condition, | condition, | ||||
| compute, | compute, | ||||
| @@ -167,7 +170,7 @@ namespace Tensorflow | |||||
| parallel_iterations: parallel_iterations, | parallel_iterations: parallel_iterations, | ||||
| back_prop: back_prop, | back_prop: back_prop, | ||||
| swap_memory: swap_memory, | swap_memory: swap_memory, | ||||
| maximum_iterations: n); | |||||
| maximum_iterations: tf.constant(n)); | |||||
| var results_flat = r_a.Select(r => r.stack()).ToList(); | var results_flat = r_a.Select(r => r.stack()).ToList(); | ||||