| @@ -0,0 +1,86 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Operations; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class Operation | |||||
| { | |||||
| /// <summary> | |||||
| /// map on the list of tensors unpacked from `elems` on dimension 0. | |||||
| /// </summary> | |||||
| /// <param name="fn"></param> | |||||
| /// <param name="elems"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <param name="parallel_iterations"></param> | |||||
| /// <param name="back_prop"></param> | |||||
| /// <param name="swap_memory"></param> | |||||
| /// <param name="infer_shape"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns>A tensor or (possibly nested) sequence of tensors.</returns> | |||||
| public static Tensor map_fn(Func<Tensor, Tensor> fn, | |||||
| Tensor elems, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| int parallel_iterations = 10, | |||||
| bool back_prop = true, | |||||
| bool swap_memory = false, | |||||
| bool infer_shape = true, | |||||
| string name = null) | |||||
| { | |||||
| var elems_flat = new[] { elems }; | |||||
| tf_with(ops.name_scope(name, "map", elems_flat), delegate | |||||
| { | |||||
| var varscope = tf.get_variable_scope(); | |||||
| elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")) | |||||
| .ToArray(); | |||||
| dtype = elems_flat.Select(elem => elem.dtype).First(); | |||||
| var dtype_flat = new[] { dtype }; | |||||
| // Convert elems to tensor array. n may be known statically. | |||||
| var static_shape = elems_flat[0].shape; | |||||
| var n = static_shape[0]; | |||||
| // TensorArrays are always flat | |||||
| var elems_ta = elems_flat.Select(elem => new TensorArray(dtype: elem.dtype, | |||||
| size: ops.convert_to_tensor(n), | |||||
| dynamic_size: false, | |||||
| infer_shape: true)).ToArray(); | |||||
| // Unpack elements | |||||
| var elems_ta_1 = new List<TensorArray>(); | |||||
| foreach (var (elem_ta, elem) in zip(elems_ta, elems_flat)) | |||||
| elems_ta_1.Add(elem_ta.unstack(elem)); | |||||
| elems_ta = elems_ta_1.ToArray(); | |||||
| var i = constant_op.constant(0); | |||||
| var accs_ta = dtype_flat.Select(dt => new TensorArray(dtype: dt, | |||||
| size: ops.convert_to_tensor(n), | |||||
| dynamic_size: false, | |||||
| infer_shape: infer_shape)).ToArray(); | |||||
| /*Func<Tensor, TensorArray> compute = (i, tas) => | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| }; | |||||
| var r_a = control_flow_ops.while_loop( | |||||
| (i, _) => i < n, | |||||
| compute, | |||||
| new[] { i, accs_ta }, | |||||
| parallel_iterations: parallel_iterations, | |||||
| back_prop: back_prop, | |||||
| swap_memory: swap_memory, | |||||
| maximum_iterations: n);*/ | |||||
| }); | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||