| @@ -39,11 +39,11 @@ namespace Tensorflow | |||
| bool input_is_sequence = nest.is_sequence(elems); | |||
| List<Tensor> input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x) : new List<Tensor> {x}; | |||
| object input_pack(List<Tensor> x) => input_is_sequence ? nest.pack_sequence_as(elems, x) : x[0]; | |||
| Tensor input_pack(List<Tensor> x) => input_is_sequence ? (Tensor)nest.pack_sequence_as(elems, x) : x[0]; | |||
| bool output_is_sequence; | |||
| Func<Tensor, List<Tensor>> output_flatten; | |||
| Func<List<Tensor>, object> output_pack; | |||
| Func<List<Tensor>, Tensor> output_pack; | |||
| if (initializer == null) | |||
| { | |||
| output_is_sequence = input_is_sequence; | |||
| @@ -54,7 +54,7 @@ namespace Tensorflow | |||
| { | |||
| output_is_sequence = nest.is_sequence(initializer); | |||
| output_flatten = (x) => output_is_sequence ? nest.flatten(x) : new List<Tensor> {x}; | |||
| output_pack = (x) => output_is_sequence ? nest.pack_sequence_as(initializer, x) : x[0]; | |||
| output_pack = (x) => output_is_sequence ? (Tensor)nest.pack_sequence_as(initializer, x) : x[0]; | |||
| } | |||
| var elems_flat = input_flatten(elems); | |||
| @@ -130,8 +130,11 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| (int, List<Tensor>, List<TensorArray>) compute(int _i, List<Tensor> a_flat_, List<TensorArray> tas) | |||
| (int, List<Tensor>, List<TensorArray>) compute(ValueTuple<int, List<Tensor>, List<TensorArray>> tuple) | |||
| { | |||
| (int _i, List<Tensor> a_flat_, List<TensorArray> tas) = tuple; | |||
| var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList()); | |||
| var packed_a = output_pack(a_flat_); | |||
| var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal? | |||
| @@ -147,19 +150,19 @@ namespace Tensorflow | |||
| } | |||
| int initial_i; | |||
| Func<int, Tensor> condition; | |||
| Func<(int, List<Tensor>, List<TensorArray>), Tensor> condition; | |||
| if (reverse) | |||
| { | |||
| initial_i = n - 1 - i; | |||
| condition = x => tf.constant(x >= 0); | |||
| condition = x => tf.constant(x.Item1 >= 0); | |||
| } | |||
| else | |||
| { | |||
| initial_i = i; | |||
| condition = x => tf.convert_to_tensor(x < n); | |||
| condition = x => tf.constant(x.Item1 < n); | |||
| } | |||
| List<TensorArray> r_a = | |||
| (_, _, List<TensorArray> r_a) = | |||
| control_flow_ops.while_loop( | |||
| condition, | |||
| compute, | |||
| @@ -167,7 +170,7 @@ namespace Tensorflow | |||
| parallel_iterations: parallel_iterations, | |||
| back_prop: back_prop, | |||
| swap_memory: swap_memory, | |||
| maximum_iterations: n); | |||
| maximum_iterations: tf.constant(n)); | |||
| var results_flat = r_a.Select(r => r.stack()).ToList(); | |||