diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 8288c94c..a2c91983 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; using static Tensorflow.Binding; +using Tensorflow.Operations; namespace Tensorflow { @@ -309,5 +310,27 @@ namespace Tensorflow /// public Tensor stop_gradient(Tensor x, string name = null) => gen_array_ops.stop_gradient(x, name: name); + + public TensorArray TensorArray(TF_DataType dtype, int size = 0, bool dynamic_size = false, + bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true, + bool infer_shape = true) + => tf.executing_eagerly() ? + new _EagerTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size, + clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, + colocate_with_first_write_call: colocate_with_first_write_call) : + new _GraphTensorArray(dtype, size: constant_op.constant(size), dynamic_size: dynamic_size, + clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, + colocate_with_first_write_call: colocate_with_first_write_call); + + public TensorArray TensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false, + bool clear_after_read = true, Shape? element_shape = null, bool colocate_with_first_write_call = true, + bool infer_shape = true) + => tf.executing_eagerly() ? + new _EagerTensorArray(dtype, size: size, dynamic_size: dynamic_size, + clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, + colocate_with_first_write_call: colocate_with_first_write_call) : + new _GraphTensorArray(dtype, size: size, dynamic_size: dynamic_size, + clear_after_read: clear_after_read, element_shape: element_shape, infer_shape: infer_shape, + colocate_with_first_write_call: colocate_with_first_write_call); } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 164facca..6b9f073c 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -294,10 +294,9 @@ namespace Tensorflow.Operations Func _create_ta = (name, element_shape, dtype_) => { - var ta = new TensorArray(dtype: dtype_, + var ta = tf.TensorArray(dtype: dtype_, size: time_steps, - element_shape: element_shape, - tensor_array_name: base_name + name); + element_shape: element_shape); return ta; }; diff --git a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs new file mode 100644 index 00000000..cf1b50af --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs @@ -0,0 +1,184 @@ +/***************************************************************************** + Copyright 2022 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Framework; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + public class _EagerTensorArray : TensorArray + { + TF_DataType _dtype; + public override TF_DataType dtype => _dtype; + + /// + /// Used to keep track of what tensors the TensorArray should be + /// colocated with. We choose to colocate the TensorArray with the + /// first tensor written to it. + /// + bool _colocate_with_first_write_call; + public override bool colocate_with_first_write_call => _colocate_with_first_write_call; + + bool _infer_shape; + public override bool infer_shape => _infer_shape; + public bool _dynamic_size; + public Shape _element_shape; + + public List _colocate_with; + + Tensor _handle; + public override Tensor handle => _handle; + Tensor _flow; + public override Tensor flow => _flow; + bool _clear_after_read; + List _tensor_array; + + public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false, + bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, + bool infer_shape = true, Shape? element_shape = null, + bool colocate_with_first_write_call = true, string name = null) + { + _flow = constant_op.constant(0); + _infer_shape = infer_shape; + _element_shape = element_shape ?? Shape.Null; + _colocate_with_first_write_call = colocate_with_first_write_call; + _dtype = dtype.as_base_dtype(); + _dynamic_size = dynamic_size; + _clear_after_read = clear_after_read; + _tensor_array = new List(); + } + + public override TensorArray unstack(Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate + { + var num_elements = array_ops.shape(value)[0]; + return scatter(indices: math_ops.range(0, num_elements), value: value, name: name); + }); + } + + public TensorArray scatter(Tensor indices, Tensor value, string name = null) + { + /*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + if (_infer_shape) + { + var shape = new Shape(value.shape.dims.Skip(1).ToArray()); + _merge_element_shape(shape); + } + + _maybe_colocate_with(value); + var flow_out = gen_data_flow_ops.tensor_array_scatter_v3( + handle: _handle, + indices: indices, + value: value, + flow_in: _flow, + name: name); + + var ta = new _EagerTensorArray(_dtype, + infer_shape: _infer_shape, + element_shape: _element_shape[0], + dynamic_size: _dynamic_size, + handle: _handle, + flow: flow_out, + colocate_with_first_write_call: _colocate_with_first_write_call); + + + return ta; + });*/ + throw new NotImplementedException(""); + } + + public void _merge_element_shape(Shape shape) + { + _element_shape.concatenate(shape); + } + + public void _maybe_colocate_with(Tensor value) + { + _colocate_with.Add(value); + } + + public override Tensor read(T index, string name = null) + { + int index_int = -1; + if (index is int int_index) + index_int = int_index; + else if (index is Tensor tensor_index) + index_int = tensor_index.numpy(); + else + throw new ValueError(""); + + if (_clear_after_read) + { + _tensor_array[index_int] = null; + } + + return _tensor_array[index_int]; + } + + public override TensorArray write(Tensor index, Tensor value, string name = null) + { + if (_infer_shape) + _element_shape = _element_shape.merge_with(value.shape); + _tensor_array.add(value); + return this; + } + + public override TensorArray write(int index, T value, string name = null) + { + var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + var index_tensor = ops.convert_to_tensor(index, name: "index"); + return write(index_tensor, value_tensor, name: name); + } + + private Tensor size(string name = null) + { + return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); + } + + public override Tensor stack(string name = null) + { + ops.colocate_with(_handle); + return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate + { + return gather(math_ops.range(0, size()), name: name); + }); + } + + public override Tensor gather(Tensor indices, string name = null) + { + var element_shape = Shape.Null; + + var value = gen_data_flow_ops.tensor_array_gather_v3( + handle: _handle, + indices: indices, + flow_in: _flow, + dtype: _dtype, + name: name, + element_shape: element_shape); + + //if (element_shape != null) + //value.set_shape(-1, element_shape.dims); + + return value; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs index 2c6527d6..16870e9f 100644 --- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -21,7 +21,7 @@ using static Tensorflow.Binding; namespace Tensorflow.Operations { - public class _GraphTensorArray + public class _GraphTensorArray : TensorArray { internal TF_DataType _dtype; public TF_DataType dtype => _dtype; @@ -47,7 +47,7 @@ namespace Tensorflow.Operations public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, - bool infer_shape = true, Shape element_shape = null, + bool infer_shape = true, Shape? element_shape = null, bool colocate_with_first_write_call = true, string name = null) { clear_after_read = clear_after_read ?? true; @@ -108,7 +108,7 @@ namespace Tensorflow.Operations }); } - public TensorArray unstack(Tensor value, string name = null) + public override TensorArray unstack(Tensor value, string name = null) { return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate { @@ -119,7 +119,7 @@ namespace Tensorflow.Operations public TensorArray scatter(Tensor indices, Tensor value, string name = null) { - return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate + /*return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _handle, value, indices }), delegate { value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); if (_infer_shape) @@ -136,7 +136,7 @@ namespace Tensorflow.Operations flow_in: _flow, name: name); - var ta = new TensorArray(_dtype, + var ta = new _GraphTensorArray(_dtype, infer_shape: _infer_shape, element_shape: _element_shape[0], dynamic_size: _dynamic_size, @@ -144,9 +144,9 @@ namespace Tensorflow.Operations flow: flow_out, colocate_with_first_write_call: _colocate_with_first_write_call); - return ta; - }); + });*/ + throw new NotImplementedException(""); } public void _merge_element_shape(Shape shape) @@ -159,11 +159,11 @@ namespace Tensorflow.Operations _colocate_with.Add(value); } - public Tensor read(Tensor index, string name = null) + public override Tensor read(T index, string name = null) { var value = gen_data_flow_ops.tensor_array_read_v3( handle: _handle, - index: index, + index: constant_op.constant(index), flow_in: _flow, dtype: _dtype, name: name); @@ -174,11 +174,10 @@ namespace Tensorflow.Operations return value; } - public TensorArray write(Tensor index, Tensor value, string name = null) + public override TensorArray write(Tensor index, Tensor value, string name = null) { return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate { - value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); _maybe_colocate_with(value); var flow_out = gen_data_flow_ops.tensor_array_write_v3( handle: _handle, @@ -191,12 +190,19 @@ namespace Tensorflow.Operations }); } + public override TensorArray write(int index, T value, string name = null) + { + var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + var index_tensor = ops.convert_to_tensor(index, name: "index"); + return write(index_tensor, value_tensor); + } + private Tensor size(string name = null) { return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); } - public Tensor stack(string name = null) + public override Tensor stack(string name = null) { ops.colocate_with(_handle); return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate @@ -205,7 +211,7 @@ namespace Tensorflow.Operations }); } - public Tensor gather(Tensor indices, string name = null) + public override Tensor gather(Tensor indices, string name = null) { var element_shape = Shape.Null; diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs index 89b2ce40..908029f5 100644 --- a/src/TensorFlowNET.Core/Operations/functional_ops.cs +++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs @@ -87,9 +87,9 @@ namespace Tensorflow // n = array_ops.shape(elems_flat[0])[0]; //} - var elems_ta = elems_flat.Select(elem => new TensorArray( + var elems_ta = elems_flat.Select(elem => tf.TensorArray( elem.dtype, - size: tf.constant(n), + size: n, dynamic_size: false, element_shape: elem.shape.dims.Skip(1).ToArray(), infer_shape: true)).ToList(); @@ -113,9 +113,9 @@ namespace Tensorflow i = 0; } - var accs_ta = a_flat.Select(init => new TensorArray( + var accs_ta = a_flat.Select(init => tf.TensorArray( dtype: init.dtype, - size: tf.constant(n), + size: n, element_shape: infer_shape ? init.shape : null, dynamic_size: false, infer_shape: infer_shape)).ToArray(); @@ -124,7 +124,7 @@ namespace Tensorflow { for (int index = 0; index < accs_ta.Length; index++) { - accs_ta[index].write(tf.constant(reverse ? n - 1 : 0), a_flat[index]); + accs_ta[index].write(reverse ? n - 1 : 0, a_flat[index]); } } diff --git a/src/TensorFlowNET.Core/Operations/map_fn.cs b/src/TensorFlowNET.Core/Operations/map_fn.cs index 1803ac55..a754f230 100644 --- a/src/TensorFlowNET.Core/Operations/map_fn.cs +++ b/src/TensorFlowNET.Core/Operations/map_fn.cs @@ -78,8 +78,8 @@ namespace Tensorflow var n = static_shape[0]; // TensorArrays are always flat - var elems_ta = elems_flat.Select(elem => new TensorArray(dtype: elem.dtype, - size: ops.convert_to_tensor(n), + var elems_ta = elems_flat.Select(elem => tf.TensorArray(dtype: elem.dtype, + size: Convert.ToInt32(n), dynamic_size: false, infer_shape: true)).ToArray(); @@ -92,8 +92,8 @@ namespace Tensorflow var i = constant_op.constant(0); - var accs_ta = dtype_flat.Select(dt => new TensorArray(dtype: dt, - size: ops.convert_to_tensor(n), + var accs_ta = dtype_flat.Select(dt => tf.TensorArray(dtype: dt, + size: Convert.ToInt32(n), dynamic_size: false, infer_shape: infer_shape)).ToArray(); diff --git a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs index dc510a41..7d2da544 100644 --- a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs @@ -1,4 +1,5 @@ using Tensorflow.Operations; +using static Tensorflow.Binding; namespace Tensorflow { @@ -12,37 +13,21 @@ namespace Tensorflow /// public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow) { - var impl = old_ta._implementation; + var new_ta = tf.TensorArray( + dtype: old_ta.dtype, + infer_shape: old_ta.infer_shape, + colocate_with_first_write_call: old_ta.colocate_with_first_write_call); - var new_ta = new TensorArray( - dtype: impl.dtype, - handle: impl.handle, - flow: flow, - infer_shape: impl.infer_shape, - colocate_with_first_write_call: impl.colocate_with_first_write_call); - - var new_impl = new_ta._implementation; - new_impl._dynamic_size = impl._dynamic_size; - new_impl._colocate_with = impl._colocate_with; - new_impl._element_shape = impl._element_shape; return new_ta; } public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow) { - var impl = old_ta; - - var new_ta = new TensorArray( - dtype: impl.dtype, - handle: impl.handle, - flow: flow, - infer_shape: impl.infer_shape, - colocate_with_first_write_call: impl.colocate_with_first_write_call); + var new_ta = tf.TensorArray( + dtype: old_ta.dtype, + infer_shape: old_ta.infer_shape, + colocate_with_first_write_call: old_ta.colocate_with_first_write_call); - var new_impl = new_ta._implementation; - new_impl._dynamic_size = impl._dynamic_size; - new_impl._colocate_with = impl._colocate_with; - new_impl._element_shape = impl._element_shape; return new_ta; } } diff --git a/src/TensorFlowNET.Core/Tensors/TensorArray.cs b/src/TensorFlowNET.Core/Tensors/TensorArray.cs index 52b364b7..fb59593c 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorArray.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorArray.cs @@ -27,42 +27,22 @@ namespace Tensorflow /// `while_loop` and `map_fn`. It supports gradient back-propagation via special /// "flow" control flow dependencies. /// - public class TensorArray : ITensorOrTensorArray + public abstract class TensorArray : ITensorOrTensorArray { - internal _GraphTensorArray _implementation; + public virtual TF_DataType dtype { get; } + public virtual Tensor handle { get; } + public virtual Tensor flow { get; } + public virtual bool infer_shape { get; } + public virtual bool colocate_with_first_write_call { get; } - public TF_DataType dtype => _implementation._dtype; - public Tensor handle => _implementation._handle; - public Tensor flow => _implementation._flow; + public abstract TensorArray unstack(Tensor value, string name = null); - public TensorArray(TF_DataType dtype, Tensor size = default, bool? clear_after_read = null, bool? dynamic_size = null, - string tensor_array_name = null, Tensor handle = null, Tensor flow = null, - bool infer_shape = true, Shape element_shape = null, - bool colocate_with_first_write_call = true, string name = null) - { - _implementation = new _GraphTensorArray(dtype, - size: size, - dynamic_size: dynamic_size, - clear_after_read: clear_after_read, - tensor_array_name: tensor_array_name, - handle: handle, - flow: flow, - infer_shape: infer_shape, - element_shape: element_shape, - colocate_with_first_write_call: colocate_with_first_write_call, - name: name); - } + public abstract Tensor read(T index, string name = null); - public TensorArray unstack(Tensor value, string name = null) - => _implementation.unstack(value, name: name); + public abstract TensorArray write(int index, T value, string name = null); + public abstract TensorArray write(Tensor index, Tensor value, string name = null); - public Tensor read(Tensor index, string name = null) - => _implementation.read(index, name: name); - - public TensorArray write(Tensor index, Tensor value, string name = null) - => _implementation.write(index, value, name: name); - - public Tensor stack(string name = null) - => _implementation.stack(name: name); + public abstract Tensor stack(string name = null); + public abstract Tensor gather(Tensor indices, string name = null); } } diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs index b2d3d891..6a12ed20 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs @@ -77,5 +77,20 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var r3 = tf.gather(p2, i2, axis: 1); Assert.AreEqual(new Shape(4,1,2), r3.shape); } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/TensorArray + /// + [TestMethod] + public void TensorArray() + { + var ta = tf.TensorArray(tf.float32, size: 0, dynamic_size: true, clear_after_read: false); + ta.write(0, 10); + ta.write(1, 20); + ta.write(2, 30); + Assert.AreEqual(ta.read(0).numpy(), 10f); + Assert.AreEqual(ta.read(1).numpy(), 20f); + Assert.AreEqual(ta.read(2).numpy(), 30f); + } } }