diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index 8288c94c..a2c91983 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using static Tensorflow.Binding;
+using Tensorflow.Operations;
namespace Tensorflow
{
@@ -309,5 +310,27 @@ namespace Tensorflow
///
public Tensor stop_gradient(Tensor x, string name = null)
=> gen_array_ops.stop_gradient(x, name: name);
+
+ public TensorArray TensorArray(TF_DataType dtype, int size = 0, bool dynamic_size = false,
+ bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true,
+ bool infer_shape = true)
+ => tf.executing_eagerly() ?
+ new _EagerTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size,
+ clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape,
+ colocate_with_first_write_call: colocate_with_first_write_call) :
+ new _GraphTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size,
+ clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape,
+ colocate_with_first_write_call: colocate_with_first_write_call);
+
+ public TensorArray TensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false,
+ bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true,
+ bool infer_shape = true)
+ => tf.executing_eagerly() ?
+ new _EagerTensorArray(dtype, size: size, dynamic_size: dynamic_size,
+ clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape,
+ colocate_with_first_write_call: colocate_with_first_write_call) :
+ new _GraphTensorArray(dtype, size: size, dynamic_size: dynamic_size,
+ clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape,
+ colocate_with_first_write_call: colocate_with_first_write_call);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
index 164facca..6b9f073c 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
@@ -294,10 +294,9 @@ namespace Tensorflow.Operations
Func _create_ta = (name, element_shape, dtype_) =>
{
- var ta = new TensorArray(dtype: dtype_,
+ var ta = tf.TensorArray(dtype: dtype_,
size: time_steps,
- element_shape: element_shape,
- tensor_array_name: base_name + name);
+ element_shape: element_shape);
return ta;
};
diff --git a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
new file mode 100644
index 00000000..cf1b50af
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
@@ -0,0 +1,184 @@
+/*****************************************************************************
+ Copyright 2022 Haiping Chen. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Tensorflow.Framework;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Operations
+{
+ public class _EagerTensorArray : TensorArray
+ {
+ TF_DataType _dtype;
+ public override TF_DataType dtype => _dtype;
+
+ ///
+ /// Used to keep track of what tensors the TensorArray should be
+ /// colocated with. We choose to colocate the TensorArray with the
+ /// first tensor written to it.
+ ///
+ bool _colocate_with_first_write_call;
+ public override bool colocate_with_first_write_call => _colocate_with_first_write_call;
+
+ bool _infer_shape;
+ public override bool infer_shape => _infer_shape;
+ public bool _dynamic_size;
+ public Shape _element_shape;
+
+ public List _colocate_with;
+
+ Tensor _handle;
+ public override Tensor handle => _handle;
+ Tensor _flow;
+ public override Tensor flow => _flow;
+ bool _clear_after_read;
+ List _tensor_array;
+
+ public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false,
+ bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
+ bool infer_shape = true, Shape? element_shape = null,
+ bool colocate_with_first_write_call = true, string name = null)
+ {
+ _flow = constant_op.constant(0);
+ _infer_shape = infer_shape;
+ _element_shape = element_shape ?? Shape.Null;
+ _colocate_with_first_write_call = colocate_with_first_write_call;
+ _dtype = dtype.as_base_dtype();
+ _dynamic_size = dynamic_size;
+ _clear_after_read = clear_after_read;
+ _tensor_array = new List();
+ }
+
+ public override TensorArray unstack(Tensor value, string name = null)
+ {
+ return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
+ {
+ var num_elements = array_ops.shape(value)[0];
+ return scatter(indices: math_ops.range(0, num_elements), value: value, name: name);
+ });
+ }
+
+ public TensorArray scatter(Tensor indices, Tensor value, string name = null)
+ {
+ /*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate
+ {
+ value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
+ if (_infer_shape)
+ {
+ var shape = new Shape(value.shape.dims.Skip(1).ToArray());
+ _merge_element_shape(shape);
+ }
+
+ _maybe_colocate_with(value);
+ var flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
+ handle: _handle,
+ indices: indices,
+ value: value,
+ flow_in: _flow,
+ name: name);
+
+ var ta = new _EagerTensorArray(_dtype,
+ infer_shape: _infer_shape,
+ element_shape: _element_shape[0],
+ dynamic_size: _dynamic_size,
+ handle: _handle,
+ flow: flow_out,
+ colocate_with_first_write_call: _colocate_with_first_write_call);
+
+
+ return ta;
+ });*/
+ throw new NotImplementedException("");
+ }
+
+ public void _merge_element_shape(Shape shape)
+ {
+ _element_shape.concatenate(shape);
+ }
+
+ public void _maybe_colocate_with(Tensor value)
+ {
+ _colocate_with.Add(value);
+ }
+
+ public override Tensor read(T index, string name = null)
+ {
+ int index_int = -1;
+ if (index is int int_index)
+ index_int = int_index;
+ else if (index is Tensor tensor_index)
+ index_int = tensor_index.numpy();
+ else
+ throw new ValueError("");
+
+ if (_clear_after_read)
+ {
+ _tensor_array[index_int] = null;
+ }
+
+ return _tensor_array[index_int];
+ }
+
+ public override TensorArray write(Tensor index, Tensor value, string name = null)
+ {
+ if (_infer_shape)
+ _element_shape = _element_shape.merge_with(value.shape);
+ _tensor_array.add(value);
+ return this;
+ }
+
+ public override TensorArray write(int index, T value, string name = null)
+ {
+ var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
+ var index_tensor = ops.convert_to_tensor(index, name: "index");
+ return write(index_tensor, value_tensor, name: name);
+ }
+
+ private Tensor size(string name = null)
+ {
+ return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name);
+ }
+
+ public override Tensor stack(string name = null)
+ {
+ ops.colocate_with(_handle);
+ return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
+ {
+ return gather(math_ops.range(0, size()), name: name);
+ });
+ }
+
+ public override Tensor gather(Tensor indices, string name = null)
+ {
+ var element_shape = Shape.Null;
+
+ var value = gen_data_flow_ops.tensor_array_gather_v3(
+ handle: _handle,
+ indices: indices,
+ flow_in: _flow,
+ dtype: _dtype,
+ name: name,
+ element_shape: element_shape);
+
+ //if (element_shape != null)
+ //value.set_shape(-1, element_shape.dims);
+
+ return value;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
index 2c6527d6..16870e9f 100644
--- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
+++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
@@ -21,7 +21,7 @@ using static Tensorflow.Binding;
namespace Tensorflow.Operations
{
- public class _GraphTensorArray
+ public class _GraphTensorArray : TensorArray
{
internal TF_DataType _dtype;
public TF_DataType dtype => _dtype;
@@ -47,7 +47,7 @@ namespace Tensorflow.Operations
public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
- bool infer_shape = true, Shape element_shape = null,
+ bool infer_shape = true, Shape? element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
clear_after_read = clear_after_read ?? true;
@@ -108,7 +108,7 @@ namespace Tensorflow.Operations
});
}
- public TensorArray unstack(Tensor value, string name = null)
+ public override TensorArray unstack(Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
{
@@ -119,7 +119,7 @@ namespace Tensorflow.Operations
public TensorArray scatter(Tensor indices, Tensor value, string name = null)
{
- return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate
+ /*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
if (_infer_shape)
@@ -136,7 +136,7 @@ namespace Tensorflow.Operations
flow_in: _flow,
name: name);
- var ta = new TensorArray(_dtype,
+ var ta = new _GraphTensorArray(_dtype,
infer_shape: _infer_shape,
element_shape: _element_shape[0],
dynamic_size: _dynamic_size,
@@ -144,9 +144,9 @@ namespace Tensorflow.Operations
flow: flow_out,
colocate_with_first_write_call: _colocate_with_first_write_call);
-
return ta;
- });
+ });*/
+ throw new NotImplementedException("");
}
public void _merge_element_shape(Shape shape)
@@ -159,11 +159,11 @@ namespace Tensorflow.Operations
_colocate_with.Add(value);
}
- public Tensor read(Tensor index, string name = null)
+ public override Tensor read(T index, string name = null)
{
var value = gen_data_flow_ops.tensor_array_read_v3(
handle: _handle,
- index: index,
+ index: constant_op.constant(index),
flow_in: _flow,
dtype: _dtype,
name: name);
@@ -174,11 +174,10 @@ namespace Tensorflow.Operations
return value;
}
- public TensorArray write(Tensor index, Tensor value, string name = null)
+ public override TensorArray write(Tensor index, Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate
{
- value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
_maybe_colocate_with(value);
var flow_out = gen_data_flow_ops.tensor_array_write_v3(
handle: _handle,
@@ -191,12 +190,19 @@ namespace Tensorflow.Operations
});
}
+ public override TensorArray write(int index, T value, string name = null)
+ {
+ var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
+ var index_tensor = ops.convert_to_tensor(index, name: "index");
+ return write(index_tensor, value_tensor);
+ }
+
private Tensor size(string name = null)
{
return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name);
}
- public Tensor stack(string name = null)
+ public override Tensor stack(string name = null)
{
ops.colocate_with(_handle);
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
@@ -205,7 +211,7 @@ namespace Tensorflow.Operations
});
}
- public Tensor gather(Tensor indices, string name = null)
+ public override Tensor gather(Tensor indices, string name = null)
{
var element_shape = Shape.Null;
diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs
index 89b2ce40..908029f5 100644
--- a/src/TensorFlowNET.Core/Operations/functional_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs
@@ -87,9 +87,9 @@ namespace Tensorflow
// n = array_ops.shape(elems_flat[0])[0];
//}
- var elems_ta = elems_flat.Select(elem => new TensorArray(
+ var elems_ta = elems_flat.Select(elem => tf.TensorArray(
elem.dtype,
- size: tf.constant(n),
+ size: n,
dynamic_size: false,
element_shape: elem.shape.dims.Skip(1).ToArray(),
infer_shape: true)).ToList();
@@ -113,9 +113,9 @@ namespace Tensorflow
i = 0;
}
- var accs_ta = a_flat.Select(init => new TensorArray(
+ var accs_ta = a_flat.Select(init => tf.TensorArray(
dtype: init.dtype,
- size: tf.constant(n),
+ size: n,
element_shape: infer_shape ? init.shape : null,
dynamic_size: false,
infer_shape: infer_shape)).ToArray();
@@ -124,7 +124,7 @@ namespace Tensorflow
{
for (int index = 0; index < accs_ta.Length; index++)
{
- accs_ta[index].write(tf.constant(reverse ? n - 1 : 0), a_flat[index]);
+ accs_ta[index].write(reverse ? n - 1 : 0, a_flat[index]);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/map_fn.cs b/src/TensorFlowNET.Core/Operations/map_fn.cs
index 1803ac55..a754f230 100644
--- a/src/TensorFlowNET.Core/Operations/map_fn.cs
+++ b/src/TensorFlowNET.Core/Operations/map_fn.cs
@@ -78,8 +78,8 @@ namespace Tensorflow
var n = static_shape[0];
// TensorArrays are always flat
- var elems_ta = elems_flat.Select(elem => new TensorArray(dtype: elem.dtype,
- size: ops.convert_to_tensor(n),
+ var elems_ta = elems_flat.Select(elem => tf.TensorArray(dtype: elem.dtype,
+ size: Convert.ToInt32(n),
dynamic_size: false,
infer_shape: true)).ToArray();
@@ -92,8 +92,8 @@ namespace Tensorflow
var i = constant_op.constant(0);
- var accs_ta = dtype_flat.Select(dt => new TensorArray(dtype: dt,
- size: ops.convert_to_tensor(n),
+ var accs_ta = dtype_flat.Select(dt => tf.TensorArray(dtype: dt,
+ size: Convert.ToInt32(n),
dynamic_size: false,
infer_shape: infer_shape)).ToArray();
diff --git a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs
index dc510a41..7d2da544 100644
--- a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs
@@ -1,4 +1,5 @@
using Tensorflow.Operations;
+using static Tensorflow.Binding;
namespace Tensorflow
{
@@ -12,37 +13,21 @@ namespace Tensorflow
///
public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow)
{
- var impl = old_ta._implementation;
+ var new_ta = tf.TensorArray(
+ dtype: old_ta.dtype,
+ infer_shape: old_ta.infer_shape,
+ colocate_with_first_write_call: old_ta.colocate_with_first_write_call);
- var new_ta = new TensorArray(
- dtype: impl.dtype,
- handle: impl.handle,
- flow: flow,
- infer_shape: impl.infer_shape,
- colocate_with_first_write_call: impl.colocate_with_first_write_call);
-
- var new_impl = new_ta._implementation;
- new_impl._dynamic_size = impl._dynamic_size;
- new_impl._colocate_with = impl._colocate_with;
- new_impl._element_shape = impl._element_shape;
return new_ta;
}
public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow)
{
- var impl = old_ta;
-
- var new_ta = new TensorArray(
- dtype: impl.dtype,
- handle: impl.handle,
- flow: flow,
- infer_shape: impl.infer_shape,
- colocate_with_first_write_call: impl.colocate_with_first_write_call);
+ var new_ta = tf.TensorArray(
+ dtype: old_ta.dtype,
+ infer_shape: old_ta.infer_shape,
+ colocate_with_first_write_call: old_ta.colocate_with_first_write_call);
- var new_impl = new_ta._implementation;
- new_impl._dynamic_size = impl._dynamic_size;
- new_impl._colocate_with = impl._colocate_with;
- new_impl._element_shape = impl._element_shape;
return new_ta;
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/TensorArray.cs b/src/TensorFlowNET.Core/Tensors/TensorArray.cs
index 52b364b7..fb59593c 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorArray.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorArray.cs
@@ -27,42 +27,22 @@ namespace Tensorflow
/// `while_loop` and `map_fn`. It supports gradient back-propagation via special
/// "flow" control flow dependencies.
///
- public class TensorArray : ITensorOrTensorArray
+ public abstract class TensorArray : ITensorOrTensorArray
{
- internal _GraphTensorArray _implementation;
+ public virtual TF_DataType dtype { get; }
+ public virtual Tensor handle { get; }
+ public virtual Tensor flow { get; }
+ public virtual bool infer_shape { get; }
+ public virtual bool colocate_with_first_write_call { get; }
- public TF_DataType dtype => _implementation._dtype;
- public Tensor handle => _implementation._handle;
- public Tensor flow => _implementation._flow;
+ public abstract TensorArray unstack(Tensor value, string name = null);
- public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null,
- string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
- bool infer_shape = true, Shape element_shape = null,
- bool colocate_with_first_write_call = true, string name = null)
- {
- _implementation = new _GraphTensorArray(dtype,
- size: size,
- dynamic_size: dynamic_size,
- clear_after_read: clear_after_read,
- tensor_array_name: tensor_array_name,
- handle: handle,
- flow: flow,
- infer_shape: infer_shape,
- element_shape: element_shape,
- colocate_with_first_write_call: colocate_with_first_write_call,
- name: name);
- }
+ public abstract Tensor read(T index, string name = null);
- public TensorArray unstack(Tensor value, string name = null)
- => _implementation.unstack(value, name: name);
+ public abstract TensorArray write(int index, T value, string name = null);
+ public abstract TensorArray write(Tensor index, Tensor value, string name = null);
- public Tensor read(Tensor index, string name = null)
- => _implementation.read(index, name: name);
-
- public TensorArray write(Tensor index, Tensor value, string name = null)
- => _implementation.write(index, value, name: name);
-
- public Tensor stack(string name = null)
- => _implementation.stack(name: name);
+ public abstract Tensor stack(string name = null);
+ public abstract Tensor gather(Tensor indices, string name = null);
}
}
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
index b2d3d891..6a12ed20 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
@@ -77,5 +77,20 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
var r3 = tf.gather(p2, i2, axis: 1);
Assert.AreEqual(new Shape(4,1,2), r3.shape);
}
+
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/TensorArray
+ ///
+ [TestMethod]
+ public void TensorArray()
+ {
+ var ta = tf.TensorArray(tf.float32, size: 0, dynamic_size: true, clear_after_read: false);
+ ta.write(0, 10);
+ ta.write(1, 20);
+ ta.write(2, 30);
+ Assert.AreEqual(ta.read(0).numpy(), 10f);
+ Assert.AreEqual(ta.read(1).numpy(), 20f);
+ Assert.AreEqual(ta.read(2).numpy(), 30f);
+ }
}
}