diff --git a/src/TensorFlowNET.Core/Operations/map_fn.cs b/src/TensorFlowNET.Core/Operations/map_fn.cs new file mode 100644 index 00000000..1206d5b9 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/map_fn.cs @@ -0,0 +1,86 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Operation + { + /// + /// map on the list of tensors unpacked from `elems` on dimension 0. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// A tensor or (possibly nested) sequence of tensors. + public static Tensor map_fn(Func fn, + Tensor elems, + TF_DataType dtype = TF_DataType.DtInvalid, + int parallel_iterations = 10, + bool back_prop = true, + bool swap_memory = false, + bool infer_shape = true, + string name = null) + { + var elems_flat = new[] { elems }; + tf_with(ops.name_scope(name, "map", elems_flat), delegate + { + var varscope = tf.get_variable_scope(); + elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")) + .ToArray(); + + dtype = elems_flat.Select(elem => elem.dtype).First(); + var dtype_flat = new[] { dtype }; + + // Convert elems to tensor array. n may be known statically. + var static_shape = elems_flat[0].shape; + + var n = static_shape[0]; + + // TensorArrays are always flat + var elems_ta = elems_flat.Select(elem => new TensorArray(dtype: elem.dtype, + size: ops.convert_to_tensor(n), + dynamic_size: false, + infer_shape: true)).ToArray(); + + // Unpack elements + var elems_ta_1 = new List(); + foreach (var (elem_ta, elem) in zip(elems_ta, elems_flat)) + elems_ta_1.Add(elem_ta.unstack(elem)); + + elems_ta = elems_ta_1.ToArray(); + + var i = constant_op.constant(0); + + var accs_ta = dtype_flat.Select(dt => new TensorArray(dtype: dt, + size: ops.convert_to_tensor(n), + dynamic_size: false, + infer_shape: infer_shape)).ToArray(); + + /*Func compute = (i, tas) => + { + throw new NotImplementedException(""); + }; + + var r_a = control_flow_ops.while_loop( + (i, _) => i < n, + compute, + new[] { i, accs_ta }, + parallel_iterations: parallel_iterations, + back_prop: back_prop, + swap_memory: swap_memory, + maximum_iterations: n);*/ + }); + + throw new NotImplementedException(""); + } + } +}