diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs index fdcc03ea..69f86349 100644 --- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs @@ -66,12 +66,14 @@ namespace Tensorflow built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { // Most basic RNN: output = new_state = act(W * input + U * state + B). var concat = array_ops.concat(new[] { inputs, state }, 1); var gate_inputs = math_ops.matmul(concat, _kernel as RefVariable); - return (inputs, inputs); + gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); + var output = _activation(gate_inputs, null); + return new[] { output, output }; } } } diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs index 56ac277e..ea701afc 100644 --- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -22,7 +22,7 @@ using static Tensorflow.Binding; namespace Tensorflow.Operations { - internal class _GraphTensorArray + public class _GraphTensorArray { internal TF_DataType _dtype; public TF_DataType dtype => _dtype; @@ -174,5 +174,57 @@ namespace Tensorflow.Operations return value; } + + public TensorArray write(Tensor index, Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayWrite", new { _handle, index, value }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + _maybe_colocate_with(value); + var flow_out = gen_data_flow_ops.tensor_array_write_v3( + handle: _handle, + index: index, + value: value, + flow_in: _flow, + name: name); + + return tensor_array_ops.build_ta_with_new_flow(this, flow_out); + }); + } + + private Tensor size(string name = null) + { + return gen_data_flow_ops.tensor_array_size_v3(_handle, _flow, name: name); + } + + public Tensor stack(string name = null) + { + ops.colocate_with(_handle); + return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate + { + return gather(math_ops.range(0, size()), name: name); + }); + } + + public Tensor gather(Tensor indices, string name = null) + { + var element_shape = new TensorShape(); + + if (_element_shape.Count > 0) + element_shape = _element_shape[0]; + + var value = gen_data_flow_ops.tensor_array_gather_v3( + handle: _handle, + indices: indices, + flow_in: _flow, + dtype: _dtype, + name: name, + element_shape: element_shape); + + //if (element_shape != null) + //value.set_shape(-1, element_shape.dims); + + return value; + } } } diff --git a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs index 8ce3b5c7..59496943 100644 --- a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Operations; namespace Tensorflow { @@ -29,5 +30,23 @@ namespace Tensorflow new_impl._element_shape = impl._element_shape; return new_ta; } + + public static TensorArray build_ta_with_new_flow(_GraphTensorArray old_ta, Tensor flow) + { + var impl = old_ta; + + var new_ta = new TensorArray( + dtype: impl.dtype, + handle: impl.handle, + flow: flow, + infer_shape: impl.infer_shape, + colocate_with_first_write_call: impl.colocate_with_first_write_call); + + var new_impl = new_ta._implementation; + new_impl._dynamic_size = impl._dynamic_size; + new_impl._colocate_with = impl._colocate_with; + new_impl._element_shape = impl._element_shape; + return new_ta; + } } }