diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index 39561990..26b29982 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -65,7 +65,9 @@ namespace Tensorflow.Layers
variable_scope scope_context_manager = null;
if (built)
{
-
+ scope_context_manager = tf.variable_scope(_scope,
+ reuse: true,
+ auxiliary_name_scope: false);
}
else
{
@@ -181,7 +183,7 @@ namespace Tensorflow.Layers
return _current_scope.original_name_scope;
}
- private void _set_scope(VariableScope scope = null)
+ protected void _set_scope(VariableScope scope = null)
{
if (_scope == null)
{
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
index ab19a271..3eb2ee95 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
@@ -14,12 +14,17 @@ namespace Tensorflow
/// Basic LSTM recurrent network cell.
/// The implementation is based on: http://arxiv.org/abs/1409.2329.
/// </summary>
- public class BasicLSTMCell : LayerRnnCell
+ public class BasicLstmCell : LayerRnnCell
{
int _num_units;
float _forget_bias;
bool _state_is_tuple;
IActivation _activation;
+ LSTMStateTuple _state;
+ VariableV1 _kernel;
+ VariableV1 _bias;
+ string _WEIGHTS_VARIABLE_NAME = "kernel";
+ string _BIAS_VARIABLE_NAME = "bias";
/// <summary>
/// Initialize the basic LSTM cell.
@@ -31,7 +36,7 @@ namespace Tensorflow
/// <param name="reuse"></param>
/// <param name="name"></param>
/// <param name="dtype"></param>
- public BasicLSTMCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true,
+ public BasicLstmCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true,
IActivation activation = null, bool? reuse = null, string name = null,
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype)
{
@@ -44,13 +49,123 @@ namespace Tensorflow
_activation = tf.nn.tanh();
}
- public LSTMStateTuple state_size
+ protected override void build(TensorShape input_shape)
+ {
+ var input_depth = input_shape.dims.Last();
+ var h_depth = _num_units;
+ _kernel = add_weight(_WEIGHTS_VARIABLE_NAME,
+ shape: new[] { input_depth + h_depth, 4 * _num_units });
+ _bias = add_weight(_BIAS_VARIABLE_NAME,
+ shape: new[] { 4 * _num_units },
+ initializer: tf.zeros_initializer);
+ built = true;
+ }
+
+ public Tensor[] __call__(Tensor inputs, LSTMStateTuple state)
+ {
+ _state = state;
+ return base.__call__(inputs);
+ }
+
+ /// <summary>
+ /// Long short-term memory cell (LSTM).
+ /// </summary>
+ /// <param name="inputs"></param>
+ /// <param name="training"></param>
+ /// <param name="state"></param>
+ /// <returns></returns>
+ protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
+ {
+ var one = constant_op.constant(1, dtype: dtypes.int32);
+ // Parameters of gates are concatenated into one multiply for efficiency.
+ Tensor c = null;
+ Tensor h = null;
+ if(_state_is_tuple)
+ (c, h) = ((Tensor)_state.c, (Tensor)_state.h);
+ else
+ {
+ // array_ops.split(value: state, num_or_size_splits: 2, axis: one);
+ throw new NotImplementedException("BasicLstmCell call");
+ }
+ var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel as RefVariable);
+ gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable);
+
+ // i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one);
+ var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]);
+
+ var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype);
+ // Note that using `add` and `multiply` instead of `+` and `*` gives a
+ // performance improvement. So using those at the cost of readability.
+ var new_c = gen_math_ops.add(
+ math_ops.multiply(c, math_ops.sigmoid(gen_math_ops.add(f, forget_bias_tensor))),
+ math_ops.multiply(math_ops.sigmoid(i), _activation.Activate(j)));
+
+ var new_h = math_ops.multiply(_activation.Activate(new_c), math_ops.sigmoid(o));
+
+
+ if (_state_is_tuple)
+ return new[] { new_c, new_h };
+ else
+ return new[] { array_ops.concat(new[] { new_c, new_h }, 1) };
+ }
+
+ public override object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
+ {
+ if (inputs != null)
+ throw new NotImplementedException("get_initial_state input is not null");
+
+ return zero_state(batch_size, dtype);
+ }
+
+ /// <summary>
+ /// Return zero-filled state tensor(s).
+ /// </summary>
+ /// <param name="batch_size"></param>
+ /// <param name="dtype"></param>
+ /// <returns></returns>
+ private LSTMStateTuple zero_state(Tensor batch_size, TF_DataType dtype)
+ {
+ LSTMStateTuple output = null;
+ tf_with(ops.name_scope($"{GetType().Name}ZeroState", values: new { batch_size }), delegate
+ {
+ output = _zero_state_tensors(state_size, batch_size, dtype);
+ });
+
+ return output;
+ }
+
+ private LSTMStateTuple _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype)
+ {
+ if (state_size is LSTMStateTuple state_size_tuple)
+ {
+ var outputs = state_size_tuple.Flatten()
+ .Select(x => (int)x)
+ .Select(s =>
+ {
+ var c = rnn_cell_impl._concat(batch_size, s);
+ var size = array_ops.zeros(c, dtype: dtype);
+
+ var c_static = rnn_cell_impl._concat(batch_size, s, @static: true);
+ size.set_shape(c_static);
+
+ return size;
+ }).ToArray();
+
+ return new LSTMStateTuple(outputs[0], outputs[1]);
+ }
+
+ throw new NotImplementedException("_zero_state_tensors");
+ }
+
+ public override object state_size
{
get
{
- return _state_is_tuple ?
- new LSTMStateTuple(_num_units, _num_units) :
- (LSTMStateTuple)(2 * _num_units);
+ if (_state_is_tuple)
+ return new LSTMStateTuple(_num_units, _num_units);
+ else
+ return 2 * _num_units;
}
}
}
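Reviewer note: the arithmetic in the new `call()` is the standard (non-peephole) LSTM update from the paper cited in the class summary. With `W` the fused `[input_depth + num_units, 4 * num_units]` kernel built in `build()`, `b` its zero-initialized bias, and tanh standing in for the default `_activation`, the code computes

$$
\begin{aligned}
[\,i,\; j,\; f,\; o\,] &= \operatorname{split}\!\big([x_t,\, h_{t-1}]\,W + b,\; 4\big),\\
c_t &= c_{t-1} \odot \sigma\!\big(f + \text{forget\_bias}\big) \;+\; \sigma(i) \odot \tanh(j),\\
h_t &= \tanh(c_t) \odot \sigma(o),
\end{aligned}
$$

returning `{ new_c, new_h }` when `state_is_tuple` is set, and the concatenation of the two along axis 1 otherwise.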
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
index da528982..b93bea8d 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
@@ -26,7 +26,7 @@ namespace Tensorflow
int _num_units;
Func _activation;
- public override LSTMStateTuple state_size => _num_units;
+ public override object state_size => _num_units;
public override int output_size => _num_units;
public VariableV1 _kernel;
string _WEIGHTS_VARIABLE_NAME = "kernel";
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs b/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs
index 7539021b..f6bf5c6e 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs
@@ -12,15 +12,10 @@ namespace Tensorflow.Operations
///
/// Only used when `state_is_tuple=True`.
/// </summary>
- public class LSTMStateTuple
+ public class LSTMStateTuple : ICanBeFlattened
{
- int c;
- int h;
-
- public LSTMStateTuple(int c)
- {
- this.c = c;
- }
+ public object c;
+ public object h;
public LSTMStateTuple(int c, int h)
{
@@ -28,14 +23,13 @@ namespace Tensorflow.Operations
this.h = h;
}
- public static implicit operator int(LSTMStateTuple tuple)
+ public LSTMStateTuple(Tensor c, Tensor h)
{
- return tuple.c;
+ this.c = c;
+ this.h = h;
}
- public static implicit operator LSTMStateTuple(int c)
- {
- return new LSTMStateTuple(c);
- }
+ public object[] Flatten()
+ => new[] { c, h };
}
}
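A small sketch (variable names and `using System.Linq;` assumed, not part of the diff) of the tuple's two roles after this change: the `int` form describes state sizes and is what `ICanBeFlattened.Flatten()` exposes to `_zero_state_tensors`, while the `Tensor` form carries the running `(c, h)` state between `__call__` invocations.

```csharp
// Size-describing form, as returned by BasicLstmCell.state_size when state_is_tuple is true.
var sizes = new LSTMStateTuple(128, 128);
var dims = sizes.Flatten()            // object[] { 128, 128 } via ICanBeFlattened
                .Select(x => (int)x)  // the cast _zero_state_tensors performs
                .ToArray();

// Tensor-carrying form, as produced at the end of each static_rnn step:
// var state = new LSTMStateTuple(new_c, new_h);   // hypothetical tensors from the cell
```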
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
index 4d277082..61d97cb9 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
@@ -49,7 +49,7 @@ namespace Tensorflow
/// difference between TF and Keras RNN cell.
/// </summary>
protected bool _is_tf_rnn_cell = false;
- public virtual LSTMStateTuple state_size { get; }
+ public virtual object state_size { get; }
public virtual int output_size { get; }
@@ -64,7 +64,7 @@ namespace Tensorflow
_is_tf_rnn_cell = true;
}
- public virtual Tensor get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
+ public virtual object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
if (inputs != null)
throw new NotImplementedException("get_initial_state input is not null");
@@ -78,11 +78,10 @@ namespace Tensorflow
/// <param name="batch_size"></param>
/// <param name="dtype"></param>
/// <returns></returns>
- public Tensor zero_state(Tensor batch_size, TF_DataType dtype)
+ private Tensor zero_state(Tensor batch_size, TF_DataType dtype)
{
Tensor output = null;
- var state_size = this.state_size;
- tf_with(ops.name_scope($"{this.GetType().Name}ZeroState", values: new { batch_size }), delegate
+ tf_with(ops.name_scope($"{GetType().Name}ZeroState", values: new { batch_size }), delegate
{
output = _zero_state_tensors(state_size, batch_size, dtype);
});
@@ -90,20 +89,25 @@ namespace Tensorflow
return output;
}
- private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype)
+ private Tensor _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype)
{
- var output = nest.map_structure(s =>
+ if(state_size is int state_size_int)
{
- var c = rnn_cell_impl._concat(batch_size, s);
- var size = array_ops.zeros(c, dtype: dtype);
+ var output = nest.map_structure(s =>
+ {
+ var c = rnn_cell_impl._concat(batch_size, s);
+ var size = array_ops.zeros(c, dtype: dtype);
- var c_static = rnn_cell_impl._concat(batch_size, s, @static: true);
- size.set_shape(c_static);
+ var c_static = rnn_cell_impl._concat(batch_size, s, @static: true);
+ size.set_shape(c_static);
- return size;
- }, state_size);
+ return size;
+ }, state_size_int);
- return output;
+ return output;
+ }
+
+ throw new NotImplementedException("_zero_state_tensors");
}
}
}
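Widening `state_size` to `object` is what lets an `int`-sized cell (BasicRNNCell) and a tuple-sized cell (BasicLstmCell) share the same base class; code that needs the concrete shape now branches on the runtime type, mirroring the `is int` / `is LSTMStateTuple` checks above. A hypothetical caller-side sketch (`cell` assumed):

```csharp
switch (cell.state_size)
{
    case int width:
        // single zero tensor of shape (batch, width) -- the RNNCell._zero_state_tensors path
        break;
    case LSTMStateTuple tuple:
        // one zero tensor per element of tuple.Flatten() -- the BasicLstmCell path
        break;
}
```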
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
index a71d035a..5509ba2c 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
@@ -29,8 +29,8 @@ namespace Tensorflow.Operations
/// <summary>
/// Creates a bidirectional recurrent neural network.
/// </summary>
- public static void static_bidirectional_rnn(BasicLSTMCell cell_fw,
- BasicLSTMCell cell_bw,
+ public static (Tensor[], LSTMStateTuple, LSTMStateTuple) static_bidirectional_rnn(BasicLstmCell cell_fw,
+ BasicLstmCell cell_bw,
Tensor[] inputs,
Tensor initial_state_fw = null,
Tensor initial_state_bw = null,
@@ -41,12 +41,17 @@ namespace Tensorflow.Operations
if (inputs == null || inputs.Length == 0)
throw new ValueError("inputs must not be empty");
+ Tensor[] output_fw = null;
+ Tensor[] output_bw = null;
+ LSTMStateTuple output_state_fw = null;
+ LSTMStateTuple output_state_bw = null;
+
tf_with(tf.variable_scope(scope ?? "bidirectional_rnn"), delegate
{
// Forward direction
tf_with(tf.variable_scope("fw"), fw_scope =>
{
- static_rnn(
+ (output_fw, output_state_fw) = static_rnn(
cell_fw,
inputs,
initial_state_fw,
@@ -54,16 +59,48 @@ namespace Tensorflow.Operations
sequence_length,
scope: fw_scope);
});
+
+ // backward direction
+ tf_with(tf.variable_scope("bw"), bw_scope =>
+ {
+ var reversed_inputs = _reverse_seq(inputs, sequence_length);
+ (output_bw, output_state_bw) = static_rnn(
+ cell_bw,
+ reversed_inputs,
+ initial_state_bw,
+ dtype,
+ sequence_length,
+ scope: bw_scope);
+ });
});
+
+ output_bw = _reverse_seq(output_bw, sequence_length);
+
+ var flat_outputs = zip(output_fw, output_bw)
+ .Select(x => array_ops.concat(new[] { x.Item1, x.Item2 }, 1))
+ .ToArray();
+
+ return (flat_outputs, output_state_fw, output_state_bw);
}
- public static void static_rnn(BasicLSTMCell cell,
+ private static Tensor[] _reverse_seq(Tensor[] input_seq, Tensor lengths)
+ {
+ if (lengths == null)
+ return input_seq.Reverse().ToArray();
+
+ throw new NotImplementedException("_reverse_seq");
+ }
+
+ public static (Tensor[], LSTMStateTuple) static_rnn(BasicLstmCell cell,
Tensor[] inputs,
Tensor initial_state,
TF_DataType dtype = TF_DataType.DtInvalid,
Tensor sequence_length = null,
VariableScope scope = null)
{
+ List<Tensor> outputs = new List<Tensor>();
+ object state = null;
+
// Create a new scope in which the caching device is either
// determined by the parent scope, or is set to place the cached
// Variable using the same placement as for the rest of the RNN.
@@ -73,12 +110,12 @@ namespace Tensorflow.Operations
throw new NotImplementedException("static_rnn");
});
else
- tf_with(tf.variable_scope(scope), varscope =>
+ tf_with(tf.variable_scope(scope), scope1 =>
{
Dimension fixed_batch_size = null;
Dimension batch_size = null;
Tensor batch_size_tensor = null;
-
+ VariableScope varscope = scope1;
// Obtain the first sequence of the input
var first_input = inputs[0];
if (first_input.TensorShape.rank != 1)
@@ -108,14 +145,31 @@ namespace Tensorflow.Operations
else
batch_size_tensor = array_ops.shape(first_input)[0];
- Tensor state = null;
if (initial_state != null)
state = initial_state;
else
{
- cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype);
+ state = cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype);
+ }
+
+ Tensor output = null;
+ if (state is LSTMStateTuple state_tuple)
+ {
+ foreach (var (time, input_) in enumerate(inputs))
+ {
+ if (time > 0)
+ varscope.reuse_variables();
+ if (sequence_length != null)
+ throw new NotImplementedException("static_rnn");
+
+ var results = cell.__call__(input_, state_tuple);
+ (output, state_tuple) = (results[1], new LSTMStateTuple(results[0], results[1]));
+ outputs.Add(output);
+ }
}
});
+
+ return (outputs.ToArray(), state as LSTMStateTuple);
}
public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor,
@@ -145,7 +199,7 @@ namespace Tensorflow.Operations
if (initial_state != null)
state = initial_state;
else
- state = cell.get_initial_state(batch_size: batch_size, dtype: dtype);
+ state = cell.get_initial_state(batch_size: batch_size, dtype: dtype) as Tensor;
var inputs = nest.pack_sequence_as(structure: inputs_tensor, flat_sequence: flat_input);
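A usage sketch of the new `static_rnn` surface. Everything outside the diff here is an assumption — the placeholder helpers, the shapes, the `using static Tensorflow.Binding;` import, and the accessibility of the `rnn` class itself (it may only be reachable from inside the library or through a `tf.nn` wrapper):

```csharp
using System.Linq;
using Tensorflow;
using Tensorflow.Operations;
using static Tensorflow.Binding;

// Assumed problem shape: 28 timesteps of 28 features, dynamic batch size.
int time_steps = 28, n_input = 28, n_hidden = 128;

// One (batch, n_input) placeholder per timestep.
var x_steps = Enumerable.Range(0, time_steps)
    .Select(t => tf.placeholder(tf.float32, shape: new TensorShape(-1, n_input), name: $"x_{t}"))
    .ToArray();

var cell = new BasicLstmCell(n_hidden);

// outputs: one (batch, n_hidden) tensor per timestep; state: the final (c, h) tuple.
var (outputs, state) = rnn.static_rnn(
    cell,
    x_steps,
    initial_state: null,     // falls back to cell.get_initial_state(...)
    dtype: tf.float32);
```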
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index c487f478..f9f2f58f 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -604,6 +604,11 @@ namespace Tensorflow
return gen_array_ops.concat_v2(values, axis, name: name);
}
+ public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat")
+ {
+ return gen_array_ops.concat_v2(values, axis, name: name);
+ }
+
public static Tensor concat(object[] values, int axis, string name = "concat")
{
return gen_array_ops.concat_v2(values, axis, name: name);
@@ -629,6 +634,16 @@ namespace Tensorflow
});
}
+ public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis,
+ string name = "split")
+ {
+ var size_splits = ops.convert_to_tensor(num_or_size_splits);
+ return gen_array_ops.split(axis: axis,
+ num_split: num_or_size_splits,
+ value: value,
+ name: name);
+ }
+
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
=> gen_array_ops.slice(input, begin, size, name: name);
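The two new overloads let `BasicLstmCell.call` feed the same scalar axis tensor to both ops; outside that path they behave like the existing int-axis versions. A hypothetical example (the placeholders stand in for real inputs and are not part of the diff):

```csharp
var a = tf.placeholder(tf.float32, new TensorShape(-1, 4));
var b = tf.placeholder(tf.float32, new TensorShape(-1, 4));

var axis = constant_op.constant(1, dtype: dtypes.int32);                  // scalar axis as a Tensor
var merged = array_ops.concat(new[] { a, b }, axis);                      // shape (batch, 8)
var parts = array_ops.split(merged, num_or_size_splits: 4, axis: axis);   // four (batch, 2) tensors
```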
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 29910d04..d151d024 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -47,7 +47,7 @@ namespace Tensorflow
/// <param name="values"></param>
/// <param name="axis"></param>
/// <param name="name"></param>
- public static Tensor concat_v2<T>(T[] values, int axis, string name = null)
+ public static Tensor concat_v2<T, Ta>(T[] values, Ta axis, string name = null)
{
var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis });
diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj
index bf508a78..f374a2fe 100644
--- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj
+++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj
@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.1</TargetTensorFlow>
- <Version>0.12.1</Version>
+ <Version>0.13.0</Version>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -18,14 +18,16 @@
<Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description>
- <AssemblyVersion>0.12.1.0</AssemblyVersion>
- <PackageReleaseNotes>Changes since v0.11.0:
+ <AssemblyVersion>0.13.0.0</AssemblyVersion>
+ <PackageReleaseNotes>Changes since v0.12.0:
1: Add ICanBeFlattened for nest.flatten2.
2: Complete the WhileContext.
3: Add tf.nn.rnn_cell.BasicRNNCell and tf.nn.dynamic_rnn.
-4: Add EstimatorSpec.</PackageReleaseNotes>
+4: Add EstimatorSpec.
+5: Add rnn.static_rnn.
+6: Add array_grad._SplitGrad().</PackageReleaseNotes>
<LangVersion>7.3</LangVersion>
- <FileVersion>0.12.1.0</FileVersion>
+ <FileVersion>0.13.0.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
true
true
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
index 846db42d..b5fdde48 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs
@@ -7,20 +7,6 @@ namespace Tensorflow
{
public partial class Tensor
{
- /// <summary>
- /// Issue unresolved, will cause name_scope problem.
- /// </summary>
- /// <param name="scalar"></param>
- /*public static implicit operator Tensor(double scalar)
- {
- return constant_op.constant(scalar);
- }*/
-
- /*public static implicit operator Tensor(int scalar)
- {
- return constant_op.constant(scalar);
- }*/
-
public static implicit operator IntPtr(Tensor tensor)
{
if (tensor._handle == IntPtr.Zero)
diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs
index 7dbacea0..54149fe1 100644
--- a/src/TensorFlowNET.Core/Util/nest.py.cs
+++ b/src/TensorFlowNET.Core/Util/nest.py.cs
@@ -526,14 +526,6 @@ namespace Tensorflow.Util
return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
}
- public static Tensor map_structure2<T>(Func<T, Tensor> func, T structure)
- {
- var flat_structure = flatten(structure);
- var mapped_flat_structure = flat_structure.Select(func).ToList();
-
- return pack_sequence_as(structure, mapped_flat_structure) as Tensor;
- }
-
/// <summary>
/// Same as map_structure, but with only one structure (no combining of multiple structures)
/// </summary>
diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs
index 52766e4f..68c75ca3 100644
--- a/src/TensorFlowNET.Core/Variables/VariableScope.cs
+++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs
@@ -74,5 +74,10 @@ namespace Tensorflow
aggregation: aggregation) as RefVariable;
});
}
+
+ public void reuse_variables()
+ {
+ _reuse = _ReuseMode.AUTO_REUSE;
+ }
}
}
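`reuse_variables()` is the piece the `static_rnn` time loop depends on: after the first step switches the scope to AUTO_REUSE, the cell's `get_variable` calls inside `__call__` fetch the existing kernel/bias instead of creating duplicates. A condensed version of that pattern, with the cell call elided and `inputs` assumed:

```csharp
tf_with(tf.variable_scope("rnn"), scope1 =>
{
    VariableScope varscope = scope1;     // same implicit conversion static_rnn uses
    foreach (var (time, input_) in enumerate(inputs))
    {
        if (time > 0)
            varscope.reuse_variables();  // from here on, variables are fetched, not created
        // var results = cell.__call__(input_, state);
    }
});
```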
diff --git a/src/TensorFlowNET.Core/Variables/_ReuseMode.cs b/src/TensorFlowNET.Core/Variables/_ReuseMode.cs
index e63e51f7..9344e824 100644
--- a/src/TensorFlowNET.Core/Variables/_ReuseMode.cs
+++ b/src/TensorFlowNET.Core/Variables/_ReuseMode.cs
@@ -5,6 +5,7 @@
/// </summary>
public enum _ReuseMode
{
+ NOT_REUSE = 0,
// Indicates that variables are to be fetched if they already exist or
// otherwise created.
AUTO_REUSE = 1