| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using NumSharp; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| @@ -128,60 +129,57 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| (int, List<Tensor>, List<TensorArray>) compute(ValueTuple<int, List<Tensor>, List<TensorArray>> tuple) | |||
| BodyItem compute(BodyItem item) | |||
| { | |||
| (int _i, List<Tensor> a_flat_, List<TensorArray> tas) = tuple; | |||
| var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList()); | |||
| var packed_a = output_pack(a_flat_); | |||
| var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal? | |||
| var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(item.I))).ToList()); | |||
| var packed_a = output_pack(item.A_Flat); | |||
| var a_out = fn(packed_a, packed_elems); | |||
| var flat_a_out = output_flatten(a_out); | |||
| for (int j = 0; j < tas.Count; j++) | |||
| for (int j = 0; j < item.Accs_ta.Count; j++) | |||
| { | |||
| tas[j].write(tf.constant(i), flat_a_out[j]); | |||
| item.Accs_ta[j].write(tf.constant(i), flat_a_out[j]); | |||
| } | |||
| var next_i = reverse ? _i-- : _i++; | |||
| return (next_i, flat_a_out, tas); | |||
| var next_i = reverse ? item.I - 1 : item.I + 1; | |||
| return new BodyItem(next_i, flat_a_out, item.Accs_ta); | |||
| } | |||
| int initial_i; | |||
| Func<(int, List<Tensor>, List<TensorArray>), Tensor> condition; | |||
| Func<BodyItem, Tensor> condition; | |||
| if (reverse) | |||
| { | |||
| initial_i = n - 1 - i; | |||
| condition = x => tf.constant(x.Item1 >= 0); | |||
| condition = x => tf.constant(x.I >= 0); | |||
| } | |||
| else | |||
| { | |||
| initial_i = i; | |||
| condition = x => tf.constant(x.Item1 < n); | |||
| condition = x => tf.constant(x.I < n); | |||
| } | |||
| (_, _, List<TensorArray> r_a) = | |||
| BodyItem bodyItem = | |||
| control_flow_ops.while_loop( | |||
| condition, | |||
| compute, | |||
| (initial_i, a_flat, accs_ta), | |||
| new BodyItem(tf.constant(initial_i), a_flat, accs_ta), | |||
| parallel_iterations: parallel_iterations, | |||
| back_prop: back_prop, | |||
| swap_memory: swap_memory, | |||
| maximum_iterations: tf.constant(n)); | |||
| var results_flat = r_a.Select(r => r.stack()).ToList(); | |||
| var results_flat = bodyItem.Accs_ta.Select(r => r.stack()).ToList(); | |||
| var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape[0])); | |||
| var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].TensorShape.with_rank_at_least(1).dims[0])); | |||
| foreach (var elem in elems_flat.Skip(1)) | |||
| { | |||
| n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape[0]))); | |||
| n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.TensorShape.with_rank_at_least(1).dims[0]))); | |||
| } | |||
| foreach (Tensor r in results_flat) | |||
| { | |||
| r.set_shape(new TensorShape(n_static).concatenate(r.shape.Skip(1).ToArray())); | |||
| r.set_shape(new TensorShape(n_static).concatenate(r.TensorShape[new Slice("1:")])); | |||
| } | |||
| // todo get working when the above caching_device is fixed | |||
| @@ -192,6 +190,37 @@ namespace Tensorflow | |||
| return output_pack(results_flat); | |||
| }); | |||
| } | |||
| internal class BodyItem : ICanBeFlattened, IPackable<BodyItem> | |||
| { | |||
| public Tensor I { get; set; } | |||
| public List<Tensor> A_Flat { get; set; } | |||
| public List<TensorArray> Accs_ta { get; set; } | |||
| public BodyItem(Tensor i, List<Tensor> a_flat, List<TensorArray> accs_ta) | |||
| { | |||
| I = i; | |||
| A_Flat = a_flat; | |||
| Accs_ta = accs_ta; | |||
| } | |||
| public object[] Flatten() | |||
| { | |||
| var elements = new List<object> { I }; | |||
| elements.AddRange(A_Flat); | |||
| elements.AddRange(Accs_ta); | |||
| return elements.ToArray(); | |||
| } | |||
| public BodyItem Pack(object[] sequences) | |||
| { | |||
| I = sequences[0] as Tensor; | |||
| A_Flat = new List<Tensor> { sequences[1] as Tensor }; | |||
| Accs_ta = new List<TensorArray> { sequences[2] as TensorArray }; | |||
| return new BodyItem(I, A_Flat, Accs_ta); | |||
| } | |||
| } | |||
| } | |||
| } | |||