diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 790e391e..9f2b493c 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -251,7 +251,7 @@ namespace Tensorflow
/// greater than clip_value_max are set to clip_value_max.
///
public Tensor clip_by_value (Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue")
- => gen_ops.clip_by_value(t, clip_value_min, clip_value_max, name);
+ => clip_ops.clip_by_value(t, clip_value_min, clip_value_max, name); // forwards all four arguments unchanged to clip_ops.clip_by_value
public Tensor sub(Tensor a, Tensor b)
=> gen_math_ops.sub(a, b);
diff --git a/src/TensorFlowNET.Core/Framework/tensor_shape.cs b/src/TensorFlowNET.Core/Framework/tensor_shape.cs
index d4e2f6cd..06d80972 100644
--- a/src/TensorFlowNET.Core/Framework/tensor_shape.cs
+++ b/src/TensorFlowNET.Core/Framework/tensor_shape.cs
@@ -24,6 +24,16 @@ namespace Tensorflow.Framework
}
}
+        // Returns Dimension(-1) when shape.rank is negative; otherwise wraps shape.dims[index].
+        public static Dimension dimension_at_index(TensorShape shape, int index)
+        {
+            if (shape.rank < 0) return new Dimension(-1);
+            return new Dimension(shape.dims[index]);
+        }
+
+        /// <summary>Returns the <c>value</c> of <paramref name="dimension"/>.</summary>
+        public static int dimension_value(Dimension dimension) { return dimension.value; }
+
public static TensorShape as_shape(this Shape shape)
=> new TensorShape(shape.Dimensions);
}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
new file mode 100644
index 00000000..ab19a271
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
@@ -0,0 +1,57 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using static Tensorflow.Binding;
+using Tensorflow.Operations.Activation;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Operations;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// Basic LSTM recurrent network cell.
+    /// The implementation is based on: http://arxiv.org/abs/1409.2329.
+    /// </summary>
+    public class BasicLSTMCell : LayerRnnCell
+    {
+        int _num_units;
+        float _forget_bias;
+        bool _state_is_tuple;
+        IActivation _activation;
+
+        /// <summary>
+        /// Initialize the basic LSTM cell.
+        /// </summary>
+        /// <param name="num_units">The number of units in the LSTM cell.</param>
+        /// <param name="forget_bias">Stored on the cell; not read anywhere else in this class yet.</param>
+        /// <param name="state_is_tuple">Selects what <see cref="state_size"/> reports: a (c, h) pair or one combined size.</param>
+        /// <param name="activation">Activation to keep; <c>tf.nn.tanh()</c> is used when null.</param>
+        /// <param name="reuse">Forwarded to the base constructor as <c>_reuse</c>.</param>
+        /// <param name="name">Forwarded to the base constructor.</param>
+        /// <param name="dtype">Forwarded to the base constructor.</param>
+        public BasicLSTMCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true,
+            IActivation activation = null, bool? reuse = null, string name = null,
+            TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype)
+        {
+            // The input spec is declared with ndim = 2.
+            input_spec = new InputSpec(ndim: 2);
+            _num_units = num_units;
+            _forget_bias = forget_bias;
+            _state_is_tuple = state_is_tuple;
+            // Fall back to tanh when the caller supplies no activation.
+            _activation = activation ?? tf.nn.tanh();
+        }
+
+        /// <summary>
+        /// Size of the cell state: (num_units, num_units) when state_is_tuple is set,
+        /// otherwise a single size of 2 * num_units.
+        /// </summary>
+        /// <remarks>Declared with override so access through the base-class property dispatches here.</remarks>
+        public override LSTMStateTuple state_size
+            => _state_is_tuple
+                ? new LSTMStateTuple(_num_units, _num_units)
+                : (LSTMStateTuple)(2 * _num_units);
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
similarity index 96%
rename from src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
rename to src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
index 69f86349..da528982 100644
--- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
@@ -16,6 +16,7 @@
using System;
using Tensorflow.Keras.Engine;
+using Tensorflow.Operations;
using static Tensorflow.Binding;
namespace Tensorflow
@@ -25,7 +26,7 @@ namespace Tensorflow
int _num_units;
Func _activation;
- public override int state_size => _num_units;
+ public override LSTMStateTuple state_size => _num_units; // the int is converted by LSTMStateTuple's implicit operator
public override int output_size => _num_units;
public VariableV1 _kernel;
string _WEIGHTS_VARIABLE_NAME = "kernel";
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs b/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs
new file mode 100644
index 00000000..7539021b
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs
@@ -0,0 +1,41 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+    /// <summary>
+    /// Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.
+    ///
+    /// Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state
+    /// and `h` is the output.
+    ///
+    /// Only used when `state_is_tuple=True`.
+    /// </summary>
+    public class LSTMStateTuple
+    {
+        int c;
+        int h;
+
+        /// <summary>
+        /// Builds a tuple from a single value stored in <c>c</c>; <c>h</c> keeps its default of 0.
+        /// </summary>
+        public LSTMStateTuple(int c) => this.c = c;
+
+        /// <summary>
+        /// Builds a tuple from explicit <c>c</c> and <c>h</c> values.
+        /// </summary>
+        public LSTMStateTuple(int c, int h)
+        {
+            this.h = h;
+            this.c = c;
+        }
+
+        /// <summary>Converts a tuple to <c>int</c> by returning its <c>c</c>.</summary>
+        /// <remarks><c>h</c> is not carried over by this conversion.</remarks>
+        public static implicit operator int(LSTMStateTuple tuple) => tuple.c;
+
+        /// <summary>Wraps an <c>int</c> as a tuple whose <c>c</c> is that value.</summary>
+        public static implicit operator LSTMStateTuple(int c) => new LSTMStateTuple(c);
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs
similarity index 100%
rename from src/TensorFlowNET.Core/Operations/LayerRNNCell.cs
rename to src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs
diff --git a/src/TensorFlowNET.Core/Operations/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
similarity index 98%
rename from src/TensorFlowNET.Core/Operations/RNNCell.cs
rename to src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
index 9902cd41..4d277082 100644
--- a/src/TensorFlowNET.Core/Operations/RNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
@@ -49,7 +49,7 @@ namespace Tensorflow
/// difference between TF and Keras RNN cell.
///
protected bool _is_tf_rnn_cell = false;
- public virtual int state_size { get; }
+ public virtual LSTMStateTuple state_size { get; } // typed as LSTMStateTuple; BasicRNNCell overrides it with an int converted implicitly
public virtual int output_size { get; }
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
index 48af7d58..a71d035a 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
@@ -18,13 +18,106 @@ using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
+using Tensorflow.Framework;
using Tensorflow.Util;
using static Tensorflow.Binding;
namespace Tensorflow.Operations
{
- internal class rnn
+ public class rnn
{
+        /// <summary>
+        /// Creates a bidirectional recurrent neural network.
+        /// </summary>
+        /// <remarks>
+        /// Only the forward direction is wired up here: cell_fw is passed to static_rnn inside the
+        /// "fw" scope. cell_bw and initial_state_bw are accepted but not used, and nothing is returned.
+        /// </remarks>
+        public static void static_bidirectional_rnn(BasicLSTMCell cell_fw, BasicLSTMCell cell_bw,
+            Tensor[] inputs, Tensor initial_state_fw = null, Tensor initial_state_bw = null,
+            TF_DataType dtype = TF_DataType.DtInvalid, Tensor sequence_length = null,
+            string scope = null)
+        {
+            var has_inputs = inputs != null && inputs.Length != 0;
+            if (!has_inputs)
+                throw new ValueError("inputs must not be empty");
+
+            var outer_scope_name = scope ?? "bidirectional_rnn";
+            tf_with(tf.variable_scope(outer_scope_name), delegate
+            {
+                // Forward direction
+                tf_with(tf.variable_scope("fw"), forward_scope =>
+                {
+                    static_rnn(cell: cell_fw,
+                        inputs: inputs,
+                        initial_state: initial_state_fw,
+                        dtype: dtype,
+                        sequence_length: sequence_length, scope: forward_scope);
+                });
+            });
+        }
+
+        /// <summary>
+        /// Work-in-progress static RNN: checks the input shapes, derives the batch size and asks
+        /// the cell for an initial state. Nothing is returned yet; a null scope is not implemented.
+        /// </summary>
+        /// <exception cref="ValueError">inputs is null or empty, or a dimension of an input's size is negative.</exception>
+        public static void static_rnn(BasicLSTMCell cell, Tensor[] inputs, Tensor initial_state,
+            TF_DataType dtype = TF_DataType.DtInvalid, Tensor sequence_length = null,
+            VariableScope scope = null)
+        {
+            // inputs[0] is read below; reject null/empty input the way static_bidirectional_rnn does.
+            if (inputs == null || inputs.Length == 0)
+                throw new ValueError("inputs must not be empty");
+
+            // Create a new scope in which the caching device is either
+            // determined by the parent scope, or is set to place the cached
+            // Variable using the same placement as for the rest of the RNN.
+            if (scope == null)
+                tf_with(tf.variable_scope("rnn"), varscope =>
+                {
+                    throw new NotImplementedException("static_rnn");
+                });
+            else
+                tf_with(tf.variable_scope(scope), varscope =>
+                {
+                    Dimension fixed_batch_size = null;
+                    Dimension batch_size = null;
+                    Tensor batch_size_tensor = null;
+
+                    // Obtain the first sequence of the input
+                    var first_input = inputs[0];
+                    if (first_input.TensorShape.rank != 1)
+                    {
+                        var input_shape = first_input.TensorShape.with_rank_at_least(2);
+                        fixed_batch_size = input_shape.dims[0];
+                        foreach (var flat_input in nest.flatten2(inputs))
+                        {
+                            input_shape = flat_input.TensorShape.with_rank_at_least(2);
+                            batch_size = tensor_shape.dimension_at_index(input_shape, 0);
+                            var input_size = input_shape[1];
+                            fixed_batch_size.merge_with(batch_size);
+                            foreach (var (i, size) in enumerate(input_size.dims))
+                                if (size < 0)
+                                    throw new ValueError($"Input size (dimension {i} of inputs) must be accessible via " +
+                                        "shape inference, but saw value None.");
+                        }
+                    }
+                    else
+                        fixed_batch_size = first_input.TensorShape.with_rank_at_least(1).dims[0];
+
+                    if (tensor_shape.dimension_value(fixed_batch_size) >= 0)
+                        batch_size = tensor_shape.dimension_value(fixed_batch_size);
+                    else
+                        batch_size_tensor = array_ops.shape(first_input)[0];
+                    // NOTE(review): the returned state is discarded, and batch_size_tensor is still null
+                    // whenever the static batch size was known above; confirm before building on this.
+                    Tensor state = initial_state;
+                    if (initial_state == null)
+                        cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype);
+                });
+        }
+
public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor,
Tensor sequence_length = null, Tensor initial_state = null,
TF_DataType dtype = TF_DataType.DtInvalid,
diff --git a/src/TensorFlowNET.Core/Operations/clip_ops.cs b/src/TensorFlowNET.Core/Operations/clip_ops.cs
new file mode 100644
index 00000000..701664f4
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/clip_ops.cs
@@ -0,0 +1,45 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using static Tensorflow.Binding;
+
+namespace Tensorflow
+{
+    public class clip_ops
+    {
+        /// <summary>Clips t by taking the minimum with clip_value_max, then the maximum with clip_value_min.</summary>
+        public static Tensor clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null)
+        {
+            var scope_values = new { t, clip_value_min, clip_value_max };
+            return tf_with(ops.name_scope(name, "clip_by_value", scope_values), delegate
+            {
+                var source = ops.convert_to_tensor(t, name: "t");
+                // Cap every value at the upper bound first.
+                var capped = math_ops.minimum(source, clip_value_max);
+                // Merging with the original shape guards against unintentional broadcasting.
+                _ = source.TensorShape.merge_with(capped.shape);
+                var clipped = math_ops.maximum(capped, clip_value_min, name: name);
+                _ = source.TensorShape.merge_with(clipped.shape);
+                return clipped;
+            });
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj
index 39279808..bf508a78 100644
--- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj
+++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj
@@ -1,7 +1,7 @@
- net472;netstandard2.0
+ netstandard2.0
TensorFlow.NET
Tensorflow
1.14.1
diff --git a/src/TensorFlowNET.Core/Tensors/Dimension.cs b/src/TensorFlowNET.Core/Tensors/Dimension.cs
index 58520270..878ba5ae 100644
--- a/src/TensorFlowNET.Core/Tensors/Dimension.cs
+++ b/src/TensorFlowNET.Core/Tensors/Dimension.cs
@@ -22,6 +22,12 @@ namespace Tensorflow
return new Dimension(_value);
}
+        /// <summary>Wraps an <c>int</c> in a new <see cref="Dimension"/>.</summary>
+        public static implicit operator Dimension(int value) { return new Dimension(value); }
+
+        /// <summary>Unwraps a <see cref="Dimension"/> to its <c>value</c>.</summary>
+        public static implicit operator int(Dimension dimension) { return dimension.value; }
+
public override string ToString() => $"Dimension({_value})";
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 67474eb9..99fba404 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -162,9 +162,9 @@ namespace Tensorflow
using (var status = new Status())
{
if (value == null)
- c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
+ c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, status);
else
- c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);
+ c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);
status.Check(true);
}