| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using NumSharp; | |||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -128,60 +129,57 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| (int, List<Tensor>, List<TensorArray>) compute(ValueTuple<int, List<Tensor>, List<TensorArray>> tuple) | |||||
| BodyItem compute(BodyItem item) | |||||
| { | { | ||||
| (int _i, List<Tensor> a_flat_, List<TensorArray> tas) = tuple; | |||||
| var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(_i))).ToList()); | |||||
| var packed_a = output_pack(a_flat_); | |||||
| var a_out = fn((Tensor)packed_a, (Tensor)packed_elems); // todo brendan are these casts legal? | |||||
| var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(tf.constant(item.I))).ToList()); | |||||
| var packed_a = output_pack(item.A_Flat); | |||||
| var a_out = fn(packed_a, packed_elems); | |||||
| var flat_a_out = output_flatten(a_out); | var flat_a_out = output_flatten(a_out); | ||||
| for (int j = 0; j < tas.Count; j++) | |||||
| for (int j = 0; j < item.Accs_ta.Count; j++) | |||||
| { | { | ||||
| tas[j].write(tf.constant(i), flat_a_out[j]); | |||||
| item.Accs_ta[j].write(tf.constant(i), flat_a_out[j]); | |||||
| } | } | ||||
| var next_i = reverse ? _i-- : _i++; | |||||
| return (next_i, flat_a_out, tas); | |||||
| var next_i = reverse ? item.I - 1 : item.I + 1; | |||||
| return new BodyItem(next_i, flat_a_out, item.Accs_ta); | |||||
| } | } | ||||
| int initial_i; | int initial_i; | ||||
| Func<(int, List<Tensor>, List<TensorArray>), Tensor> condition; | |||||
| Func<BodyItem, Tensor> condition; | |||||
| if (reverse) | if (reverse) | ||||
| { | { | ||||
| initial_i = n - 1 - i; | initial_i = n - 1 - i; | ||||
| condition = x => tf.constant(x.Item1 >= 0); | |||||
| condition = x => tf.constant(x.I >= 0); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| initial_i = i; | initial_i = i; | ||||
| condition = x => tf.constant(x.Item1 < n); | |||||
| condition = x => tf.constant(x.I < n); | |||||
| } | } | ||||
| (_, _, List<TensorArray> r_a) = | |||||
| BodyItem bodyItem = | |||||
| control_flow_ops.while_loop( | control_flow_ops.while_loop( | ||||
| condition, | condition, | ||||
| compute, | compute, | ||||
| (initial_i, a_flat, accs_ta), | |||||
| new BodyItem(tf.constant(initial_i), a_flat, accs_ta), | |||||
| parallel_iterations: parallel_iterations, | parallel_iterations: parallel_iterations, | ||||
| back_prop: back_prop, | back_prop: back_prop, | ||||
| swap_memory: swap_memory, | swap_memory: swap_memory, | ||||
| maximum_iterations: tf.constant(n)); | maximum_iterations: tf.constant(n)); | ||||
| var results_flat = r_a.Select(r => r.stack()).ToList(); | |||||
| var results_flat = bodyItem.Accs_ta.Select(r => r.stack()).ToList(); | |||||
| var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].shape[0])); | |||||
| var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].TensorShape.with_rank_at_least(1).dims[0])); | |||||
| foreach (var elem in elems_flat.Skip(1)) | foreach (var elem in elems_flat.Skip(1)) | ||||
| { | { | ||||
| n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.shape[0]))); | |||||
| n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.TensorShape.with_rank_at_least(1).dims[0]))); | |||||
| } | } | ||||
| foreach (Tensor r in results_flat) | foreach (Tensor r in results_flat) | ||||
| { | { | ||||
| r.set_shape(new TensorShape(n_static).concatenate(r.shape.Skip(1).ToArray())); | |||||
| r.set_shape(new TensorShape(n_static).concatenate(r.TensorShape[new Slice("1:")])); | |||||
| } | } | ||||
| // todo get working when the above caching_device is fixed | // todo get working when the above caching_device is fixed | ||||
| @@ -192,6 +190,37 @@ namespace Tensorflow | |||||
| return output_pack(results_flat); | return output_pack(results_flat); | ||||
| }); | }); | ||||
| } | } | ||||
| internal class BodyItem : ICanBeFlattened, IPackable<BodyItem> | |||||
| { | |||||
| public Tensor I { get; set; } | |||||
| public List<Tensor> A_Flat { get; set; } | |||||
| public List<TensorArray> Accs_ta { get; set; } | |||||
| public BodyItem(Tensor i, List<Tensor> a_flat, List<TensorArray> accs_ta) | |||||
| { | |||||
| I = i; | |||||
| A_Flat = a_flat; | |||||
| Accs_ta = accs_ta; | |||||
| } | |||||
| public object[] Flatten() | |||||
| { | |||||
| var elements = new List<object> { I }; | |||||
| elements.AddRange(A_Flat); | |||||
| elements.AddRange(Accs_ta); | |||||
| return elements.ToArray(); | |||||
| } | |||||
| public BodyItem Pack(object[] sequences) | |||||
| { | |||||
| I = sequences[0] as Tensor; | |||||
| A_Flat = new List<Tensor> { sequences[1] as Tensor }; | |||||
| Accs_ta = new List<TensorArray> { sequences[2] as TensorArray }; | |||||
| return new BodyItem(I, A_Flat, Accs_ta); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||