diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs
index cd7a3642..c91f283d 100644
--- a/src/TensorFlowNET.Core/APIs/tf.ops.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using System;
using System.Collections.Generic;
namespace Tensorflow
@@ -61,5 +62,34 @@ namespace Tensorflow
///
public Tensor no_op(string name = null)
=> gen_control_flow_ops.no_op(name: name);
+
+ ///
+ /// map on the list of tensors unpacked from `elems` on dimension 0.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// A tensor or (possibly nested) sequence of tensors.
+ public Tensor map_fn(Func fn,
+ Tensor elems,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ int parallel_iterations = -1,
+ bool back_prop = true,
+ bool swap_memory = false,
+ bool infer_shape = true,
+ string name = null)
+ => Operation.map_fn(fn,
+ elems,
+ dtype,
+ parallel_iterations: parallel_iterations,
+ back_prop: back_prop,
+ swap_memory: swap_memory,
+ infer_shape: infer_shape,
+ name: name);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
index 1b68d1cd..8e7425e5 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
@@ -145,7 +145,7 @@ namespace Tensorflow.Operations
{
var ta = new TensorArray(dtype: dtype_,
size: time_steps,
- element_shape: element_shape,
+ element_shape: new[] { element_shape },
tensor_array_name: base_name + name);
return ta;
};
diff --git a/src/TensorFlowNET.Core/Operations/TensorArray.cs b/src/TensorFlowNET.Core/Operations/TensorArray.cs
index 858dac47..7251bf85 100644
--- a/src/TensorFlowNET.Core/Operations/TensorArray.cs
+++ b/src/TensorFlowNET.Core/Operations/TensorArray.cs
@@ -33,9 +33,13 @@ namespace Tensorflow.Operations
{
_GraphTensorArray _implementation;
- public TensorArray(TF_DataType dtype, Tensor size = null, bool? clear_after_read = null, bool? dynamic_size = null,
+ public TF_DataType dtype => _implementation._dtype;
+ public Tensor handle => _implementation._handle;
+ public Tensor flow => _implementation._flow;
+
+ public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null,
string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
- bool infer_shape = true, TensorShape element_shape = null,
+ bool infer_shape = true, TensorShape[] element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
_implementation = new _GraphTensorArray(dtype,
@@ -50,5 +54,8 @@ namespace Tensorflow.Operations
colocate_with_first_write_call: colocate_with_first_write_call,
name: name);
}
+
+ public TensorArray unstack(Tensor value, string name = null)
+ => _implementation.unstack(value, name: name);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
index 4c700a5f..bd919ad8 100644
--- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
+++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
@@ -16,6 +16,7 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
using static Tensorflow.Binding;
@@ -23,7 +24,7 @@ namespace Tensorflow.Operations
{
internal class _GraphTensorArray
{
- TF_DataType _dtype;
+ internal TF_DataType _dtype;
///
/// Used to keep track of what tensors the TensorArray should be
@@ -33,23 +34,27 @@ namespace Tensorflow.Operations
bool _colocate_with_first_write_call;
bool _infer_shape;
+ bool _dynamic_size;
List _element_shape;
- object _colocate_with;
+ List _colocate_with;
+
+ internal Tensor _handle;
+ internal Tensor _flow;
public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
- bool infer_shape = true, TensorShape element_shape = null,
+ bool infer_shape = true, TensorShape[] element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
clear_after_read = clear_after_read ?? true;
dynamic_size = dynamic_size ?? false;
-
+ _dynamic_size = dynamic_size.Value;
_dtype = dtype;
_colocate_with_first_write_call = colocate_with_first_write_call;
if (colocate_with_first_write_call)
- _colocate_with = new Tensor[0];
+ _colocate_with = new List();
// Record the current static shape for the array elements. The element
// shape is defined either by `element_shape` or the shape of the tensor
@@ -66,11 +71,12 @@ namespace Tensorflow.Operations
_element_shape = new List { };
}
- tf_with(ops.name_scope(name, "", new { handle, size, flow }), scope =>
+ tf_with(ops.name_scope(name, "TensorArray", new { handle, size, flow }), scope =>
{
if(handle != null)
{
-
+ _handle = handle;
+ _flow = flow;
}
else
{
@@ -89,14 +95,65 @@ namespace Tensorflow.Operations
if (colocate_with_first_write_call)
{
ops.colocate_with(ignore_existing: true);
- create();
+ (_handle, _flow) = create();
}
else
{
-
+ (_handle, _flow) = create();
}
}
});
}
+
+ public TensorArray unstack(Tensor value, string name = null)
+ {
+ return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
+ {
+ var num_elements = array_ops.shape(value)[0];
+ return scatter(indices: math_ops.range(0, num_elements), value: value, name: name);
+ });
+ }
+
+ public TensorArray scatter(Tensor indices, Tensor value, string name = null)
+ {
+ return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate
+ {
+ value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
+ if (_infer_shape)
+ {
+ var shape = new TensorShape(value.TensorShape.dims.Skip(1).ToArray());
+ _merge_element_shape(shape);
+ }
+
+ _maybe_colocate_with(value);
+ var flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
+ handle: _handle,
+ indices: indices,
+ value: value,
+ flow_in: _flow,
+ name: name);
+
+ var ta = new TensorArray(_dtype,
+ infer_shape:_infer_shape,
+ element_shape: _element_shape.ToArray(),
+ dynamic_size: _dynamic_size,
+ handle: _handle,
+ flow: flow_out,
+ colocate_with_first_write_call: _colocate_with_first_write_call);
+
+
+ return ta;
+ });
+ }
+
+ public void _merge_element_shape(TensorShape shape)
+ {
+ _element_shape.Add(shape);
+ }
+
+ public void _maybe_colocate_with(Tensor value)
+ {
+ _colocate_with.Add(value);
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
index 4e5bd1f6..fa194934 100644
--- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
@@ -27,10 +27,13 @@ namespace Tensorflow
return _op.output;
}
- public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid,
- int[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
+ public static (Tensor, Tensor) tensor_array_v3(T size, TF_DataType dtype = TF_DataType.DtInvalid,
+ TensorShape[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null)
{
+ if (tensor_array_name == null)
+ tensor_array_name = string.Empty;
+
var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new
{
size,
@@ -42,7 +45,21 @@ namespace Tensorflow
tensor_array_name
});
- return (null, null);
+ return (_op.outputs[0], _op.outputs[1]);
+ }
+
+ public static Tensor tensor_array_scatter_v3(Tensor handle, Tensor indices, Tensor value,
+ Tensor flow_in, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("TensorArrayScatterV3", name, new
+ {
+ handle,
+ indices,
+ value,
+ flow_in
+ });
+
+ return _op.output;
}
public static Tensor padding_fifo_queue_v2(TF_DataType[] component_types, TensorShape[] shapes,
diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs
index 226a4839..2086bb36 100644
--- a/test/TensorFlowNET.UnitTest/OperationsTest.cs
+++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs
@@ -1494,5 +1494,22 @@ namespace TensorFlowNET.UnitTest
}
#endregion
}
+
+ [TestMethod]
+ public void map_fn()
+ {
+ var a = tf.constant(new[] { 1, 2, 3, 4 });
+ var b = tf.constant(new[] { 17, 12, 11, 10 });
+ var ab = tf.stack(new[] { a, b }, 1);
+
+ Func map_operation = (value_ab) =>
+ {
+ var value_a = value_ab[0];
+ var value_b = value_ab[1];
+ return value_a + value_b;
+ };
+
+ var map_result = tf.map_fn(map_operation, ab);
+ }
}
}