diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index 39561990..26b29982 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -65,7 +65,9 @@ namespace Tensorflow.Layers
             variable_scope scope_context_manager = null;
             if (built)
             {
-
+                scope_context_manager = tf.variable_scope(_scope,
+                    reuse: true,
+                    auxiliary_name_scope: false);
             }
             else
             {
@@ -181,7 +183,7 @@ namespace Tensorflow.Layers
             return _current_scope.original_name_scope;
         }
 
-        private void _set_scope(VariableScope scope = null)
+        protected void _set_scope(VariableScope scope = null)
         {
             if (_scope == null)
             {
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
index ab19a271..3eb2ee95 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
@@ -14,12 +14,17 @@ namespace Tensorflow
     /// Basic LSTM recurrent network cell.
     /// The implementation is based on: http://arxiv.org/abs/1409.2329.
     /// 
-    public class BasicLSTMCell : LayerRnnCell
+    public class BasicLstmCell : LayerRnnCell
     {
         int _num_units;
         float _forget_bias;
         bool _state_is_tuple;
         IActivation _activation;
+        LSTMStateTuple _state;
+        VariableV1 _kernel;
+        VariableV1 _bias;
+        string _WEIGHTS_VARIABLE_NAME = "kernel";
+        string _BIAS_VARIABLE_NAME = "bias";
 
         /// 
         /// Initialize the basic LSTM cell.
@@ -31,7 +36,7 @@ namespace Tensorflow
         /// 
         /// 
         /// 
-        public BasicLSTMCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true,
+        public BasicLstmCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true,
             IActivation activation = null, bool? reuse = null, string name = null,
             TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype)
         {
@@ -44,13 +49,123 @@ namespace Tensorflow
                 _activation = tf.nn.tanh();
         }
 
-        public LSTMStateTuple state_size
+        protected override void build(TensorShape input_shape)
+        {
+            var input_depth = input_shape.dims.Last();
+            var h_depth = _num_units;
+            _kernel = add_weight(_WEIGHTS_VARIABLE_NAME,
+                shape: new[] { input_depth + h_depth, 4 * _num_units });
+            _bias = add_weight(_BIAS_VARIABLE_NAME,
+                shape: new[] { 4 * _num_units },
+                initializer: tf.zeros_initializer);
+            built = true;
+        }
+
+        public Tensor[] __call__(Tensor inputs, LSTMStateTuple state)
+        {
+            _state = state;
+            return base.__call__(inputs);
+        }
+
+        /// 
+        /// Long short-term memory cell (LSTM).
+        /// 
+        /// 
+        /// 
+        /// 
+        /// 
+        protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
+        {
+            var one = constant_op.constant(1, dtype: dtypes.int32);
+            // Parameters of gates are concatenated into one multiply for efficiency.
+            Tensor c = null;
+            Tensor h = null;
+            if(_state_is_tuple)
+                (c, h) = ((Tensor)_state.c, (Tensor)_state.h);
+            else
+            {
+                // array_ops.split(value: state, num_or_size_splits: 2, axis: one);
+                throw new NotImplementedException("BasicLstmCell call");
+            }
+            var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel as RefVariable);
+            gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable);
+
+            // i = input_gate, j = new_input, f = forget_gate, o = output_gate
+            var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one);
+            var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]);
+
+            var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype);
+            // Note that using `add` and `multiply` instead of `+` and `*` gives a
+            // performance improvement. So using those at the cost of readability.
+            var new_c = gen_math_ops.add(
+                math_ops.multiply(c, math_ops.sigmoid(gen_math_ops.add(f, forget_bias_tensor))),
+                math_ops.multiply(math_ops.sigmoid(i), _activation.Activate(j)));
+
+            var new_h = math_ops.multiply(_activation.Activate(new_c), math_ops.sigmoid(o));
+
+
+            if (_state_is_tuple)
+                return new[] { new_c, new_h };
+            else
+                return new[] { array_ops.concat(new[] { new_c, new_h }, 1) };
+        }
+
+        public override object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
+        {
+            if (inputs != null)
+                throw new NotImplementedException("get_initial_state input is not null");
+
+            return zero_state(batch_size, dtype);
+        }
+
+        /// 
+        /// Return zero-filled state tensor(s).
+        /// 
+        /// 
+        /// 
+        /// 
+        private LSTMStateTuple zero_state(Tensor batch_size, TF_DataType dtype)
+        {
+            LSTMStateTuple output = null;
+            tf_with(ops.name_scope($"{GetType().Name}ZeroState", values: new { batch_size }), delegate
+            {
+                output = _zero_state_tensors(state_size, batch_size, dtype);
+            });
+
+            return output;
+        }
+
+        private LSTMStateTuple _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype)
+        {
+            if (state_size is LSTMStateTuple state_size_tuple)
+            {
+                var outputs = state_size_tuple.Flatten()
+                    .Select(x => (int)x)
+                    .Select(s =>
+                    {
+                        var c = rnn_cell_impl._concat(batch_size, s);
+                        var size = array_ops.zeros(c, dtype: dtype);
+
+                        var c_static = rnn_cell_impl._concat(batch_size, s, @static: true);
+                        size.set_shape(c_static);
+
+                        return size;
+                    }).ToArray();
+
+                return new LSTMStateTuple(outputs[0], outputs[1]);
+            }
+
+            throw new NotImplementedException("_zero_state_tensors");
+        }
+
+        public override object state_size
         {
             get
             {
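
For reference, the gate arithmetic that call() assembles as graph ops above, written out with plain scalars for a single unit. This is an illustrative, self-contained C# sketch (the numbers are made up; the default tanh activation and forget_bias = 1.0 are assumed):

    using System;

    class LstmStepSketch
    {
        // Scalar version of the update in BasicLstmCell.call:
        //   new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * tanh(j)
        //   new_h = tanh(new_c) * sigmoid(o)
        static double Sigmoid(double x) => 1.0 / (1.0 + Math.Exp(-x));

        static void Main()
        {
            // Pretend the concatenated matmul + bias_add already produced the four
            // gate pre-activations (i = input, j = new input, f = forget, o = output).
            double i = 0.3, j = -0.1, f = 0.8, o = 0.5;
            double c = 0.2;               // previous cell state
            double forget_bias = 1.0;     // same default as the cell above

            double new_c = c * Sigmoid(f + forget_bias) + Sigmoid(i) * Math.Tanh(j);
            double new_h = Math.Tanh(new_c) * Sigmoid(o);

            Console.WriteLine($"new_c = {new_c:F4}, new_h = {new_h:F4}");
        }
    }

In tensor form, the kernel created in build() has shape [input_depth + num_units, 4 * num_units], so one matmul produces all four gates at once and split(..., num_or_size_splits: 4, axis: one) separates them.
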
-                return _state_is_tuple ?
-                    new LSTMStateTuple(_num_units, _num_units) :
-                    (LSTMStateTuple)(2 * _num_units);
+                if (_state_is_tuple)
+                    return new LSTMStateTuple(_num_units, _num_units);
+                else
+                    return 2 * _num_units;
             }
         }
     }
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
index da528982..b93bea8d 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
@@ -26,7 +26,7 @@ namespace Tensorflow
         int _num_units;
         Func _activation;
 
-        public override LSTMStateTuple state_size => _num_units;
+        public override object state_size => _num_units;
         public override int output_size => _num_units;
         public VariableV1 _kernel;
         string _WEIGHTS_VARIABLE_NAME = "kernel";
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs b/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs
index 7539021b..f6bf5c6e 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs
@@ -12,15 +12,10 @@ namespace Tensorflow.Operations
     /// 
    /// Only used when `state_is_tuple=True`.
     /// 
-    public class LSTMStateTuple
+    public class LSTMStateTuple : ICanBeFlattened
     {
-        int c;
-        int h;
-
-        public LSTMStateTuple(int c)
-        {
-            this.c = c;
-        }
+        public object c;
+        public object h;
 
         public LSTMStateTuple(int c, int h)
         {
@@ -28,14 +23,13 @@ namespace Tensorflow.Operations
             this.h = h;
         }
 
-        public static implicit operator int(LSTMStateTuple tuple)
+        public LSTMStateTuple(Tensor c, Tensor h)
         {
-            return tuple.c;
+            this.c = c;
+            this.h = h;
         }
 
-        public static implicit operator LSTMStateTuple(int c)
-        {
-            return new LSTMStateTuple(c);
-        }
+        public object[] Flatten()
+            => new[] { c, h };
     }
 }
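
LSTMStateTuple now doubles as the flatten-able container that the zero-state helpers above rely on. A small illustration of the members introduced here (hypothetical usage, relying only on what this patch defines):

    // Size form: c and h widths as ints, e.g. what BasicLstmCell.state_size returns.
    var size = new LSTMStateTuple(128, 128);
    object[] flat = size.Flatten();          // { 128, 128 }, consumed by _zero_state_tensors

    // Tensor form, as produced by BasicLstmCell.call: cell state first, then hidden state.
    // var state = new LSTMStateTuple(new_c, new_h);
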
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
index 4d277082..61d97cb9 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
@@ -49,7 +49,7 @@ namespace Tensorflow
         /// difference between TF and Keras RNN cell.
         /// 
         protected bool _is_tf_rnn_cell = false;
 
-        public virtual LSTMStateTuple state_size { get; }
+        public virtual object state_size { get; }
 
         public virtual int output_size { get; }
@@ -64,7 +64,7 @@ namespace Tensorflow
             _is_tf_rnn_cell = true;
         }
 
-        public virtual Tensor get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
+        public virtual object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
         {
             if (inputs != null)
                 throw new NotImplementedException("get_initial_state input is not null");
@@ -78,11 +78,10 @@ namespace Tensorflow
         /// 
        /// 
        /// 
-        public Tensor zero_state(Tensor batch_size, TF_DataType dtype)
+        private Tensor zero_state(Tensor batch_size, TF_DataType dtype)
         {
             Tensor output = null;
-            var state_size = this.state_size;
-            tf_with(ops.name_scope($"{this.GetType().Name}ZeroState", values: new { batch_size }), delegate
+            tf_with(ops.name_scope($"{GetType().Name}ZeroState", values: new { batch_size }), delegate
             {
                 output = _zero_state_tensors(state_size, batch_size, dtype);
             });
@@ -90,20 +89,25 @@ namespace Tensorflow
             return output;
         }
 
-        private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype)
+        private Tensor _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype)
         {
-            var output = nest.map_structure(s =>
+            if(state_size is int state_size_int)
             {
-                var c = rnn_cell_impl._concat(batch_size, s);
-                var size = array_ops.zeros(c, dtype: dtype);
+                var output = nest.map_structure(s =>
+                {
+                    var c = rnn_cell_impl._concat(batch_size, s);
+                    var size = array_ops.zeros(c, dtype: dtype);
 
-                var c_static = rnn_cell_impl._concat(batch_size, s, @static: true);
-                size.set_shape(c_static);
+                    var c_static = rnn_cell_impl._concat(batch_size, s, @static: true);
+                    size.set_shape(c_static);
 
-                return size;
-            }, state_size);
+                    return size;
+                }, state_size_int);
 
-            return output;
+                return output;
+            }
+
+            throw new NotImplementedException("_zero_state_tensors");
         }
     }
 }
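
With state_size loosened to object, the base RnnCell keeps handling plain int state sizes while BasicLstmCell overrides the whole zero-state path for tuple states. A hedged sketch of the difference (the shape comments are assumptions read off the code above, not verified output):

    // BasicLstmCell: state_size is an LSTMStateTuple, so the initial state is a
    // pair of zero tensors, one for c and one for h.
    var cell = new BasicLstmCell(128);
    var init = (LSTMStateTuple)cell.get_initial_state(batch_size: tf.constant(32), dtype: tf.float32);
    // init.c -> zeros of shape [32, 128], init.h -> zeros of shape [32, 128]

    // BasicRNNCell: state_size stays an int, so RnnCell._zero_state_tensors
    // produces a single [batch, num_units] zero tensor.
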
"bidirectional_rnn"), delegate { // Forward direction tf_with(tf.variable_scope("fw"), fw_scope => { - static_rnn( + (output_fw, output_state_fw) = static_rnn( cell_fw, inputs, initial_state_fw, @@ -54,16 +59,48 @@ namespace Tensorflow.Operations sequence_length, scope: fw_scope); }); + + // backward direction + tf_with(tf.variable_scope("bw"), bw_scope => + { + var reversed_inputs = _reverse_seq(inputs, sequence_length); + (output_bw, output_state_bw) = static_rnn( + cell_bw, + reversed_inputs, + initial_state_bw, + dtype, + sequence_length, + scope: bw_scope); + }); }); + + output_bw = _reverse_seq(output_bw, sequence_length); + + var flat_outputs = zip(output_fw, output_bw) + .Select(x => array_ops.concat(new[] { x.Item1, x.Item2 }, 1)) + .ToArray(); + + return (flat_outputs, output_state_fw, output_state_bw); } - public static void static_rnn(BasicLSTMCell cell, + private static Tensor[] _reverse_seq(Tensor[] input_seq, Tensor lengths) + { + if (lengths == null) + return input_seq.Reverse().ToArray(); + + throw new NotImplementedException("_reverse_seq"); + } + + public static (Tensor[], LSTMStateTuple) static_rnn(BasicLstmCell cell, Tensor[] inputs, Tensor initial_state, TF_DataType dtype = TF_DataType.DtInvalid, Tensor sequence_length = null, VariableScope scope = null) { + List outputs = new List(); + object state = null; + // Create a new scope in which the caching device is either // determined by the parent scope, or is set to place the cached // Variable using the same placement as for the rest of the RNN. @@ -73,12 +110,12 @@ namespace Tensorflow.Operations throw new NotImplementedException("static_rnn"); }); else - tf_with(tf.variable_scope(scope), varscope => + tf_with(tf.variable_scope(scope), scope1 => { Dimension fixed_batch_size = null; Dimension batch_size = null; Tensor batch_size_tensor = null; - + VariableScope varscope = scope1; // Obtain the first sequence of the input var first_input = inputs[0]; if (first_input.TensorShape.rank != 1) @@ -108,14 +145,31 @@ namespace Tensorflow.Operations else batch_size_tensor = array_ops.shape(first_input)[0]; - Tensor state = null; if (initial_state != null) state = initial_state; else { - cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype); + state = cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype); + } + + Tensor output = null; + if (state is LSTMStateTuple state_tuple) + { + foreach (var (time, input_) in enumerate(inputs)) + { + if (time > 0) + varscope.reuse_variables(); + if (sequence_length != null) + throw new NotImplementedException("static_rnn"); + + var results = cell.__call__(input_, state_tuple); + (output, state_tuple) = (results[1], new LSTMStateTuple(results[0], results[1])); + outputs.Add(output); + } } }); + + return (outputs.ToArray(), state as LSTMStateTuple); } public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, @@ -145,7 +199,7 @@ namespace Tensorflow.Operations if (initial_state != null) state = initial_state; else - state = cell.get_initial_state(batch_size: batch_size, dtype: dtype); + state = cell.get_initial_state(batch_size: batch_size, dtype: dtype) as Tensor; var inputs = nest.pack_sequence_as(structure: inputs_tensor, flat_sequence: flat_input); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index c487f478..f9f2f58f 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -604,6 +604,11 @@ namespace Tensorflow return 
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index c487f478..f9f2f58f 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -604,6 +604,11 @@ namespace Tensorflow
             return gen_array_ops.concat_v2(values, axis, name: name);
         }
 
+        public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat")
+        {
+            return gen_array_ops.concat_v2(values, axis, name: name);
+        }
+
         public static Tensor concat(object[] values, int axis, string name = "concat")
         {
             return gen_array_ops.concat_v2(values, axis, name: name);
@@ -629,6 +634,16 @@ namespace Tensorflow
             });
         }
 
+        public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis,
+            string name = "split")
+        {
+            var size_splits = ops.convert_to_tensor(num_or_size_splits);
+            return gen_array_ops.split(axis: axis,
+                num_split: num_or_size_splits,
+                value: value,
+                name: name);
+        }
+
         public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
             => gen_array_ops.slice(input, begin, size, name: name);
 
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 29910d04..d151d024 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -47,7 +47,7 @@ namespace Tensorflow
         /// 
        /// 
        /// 
-        public static Tensor concat_v2<T>(T[] values, int axis, string name = null)
+        public static Tensor concat_v2<T, Ta>(T[] values, Ta axis, string name = null)
         {
             var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis });
 
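
The two overloads above let the LSTM cell pass its split/concat axis as a graph tensor rather than a plain int. A minimal sketch mirroring those calls (it assumes array_ops and constant_op are callable from the caller's context, as they are from the cell code; shapes in the comments are assumptions):

    var a = tf.placeholder(tf.float32, shape: new TensorShape(-1, 10));
    var b = tf.placeholder(tf.float32, shape: new TensorShape(-1, 128));
    var axis = constant_op.constant(1, dtype: dtypes.int32);

    var merged = array_ops.concat(new[] { a, b }, 1);                          // [batch, 138]
    var halves = array_ops.split(value: merged, num_or_size_splits: 2, axis: axis);
    // halves[0], halves[1]: two [batch, 69] tensors
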
- /// - /// - /*public static implicit operator Tensor(double scalar) - { - return constant_op.constant(scalar); - }*/ - - /*public static implicit operator Tensor(int scalar) - { - return constant_op.constant(scalar); - }*/ - public static implicit operator IntPtr(Tensor tensor) { if (tensor._handle == IntPtr.Zero) diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index 7dbacea0..54149fe1 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -526,14 +526,6 @@ namespace Tensorflow.Util return pack_sequence_as(structure, mapped_flat_structure) as Tensor; } - public static Tensor map_structure2(Func func, T structure) - { - var flat_structure = flatten(structure); - var mapped_flat_structure = flat_structure.Select(func).ToList(); - - return pack_sequence_as(structure, mapped_flat_structure) as Tensor; - } - /// /// Same as map_structure, but with only one structure (no combining of multiple structures) /// diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index 52766e4f..68c75ca3 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -74,5 +74,10 @@ namespace Tensorflow aggregation: aggregation) as RefVariable; }); } + + public void reuse_variables() + { + _reuse = _ReuseMode.AUTO_REUSE; + } } } diff --git a/src/TensorFlowNET.Core/Variables/_ReuseMode.cs b/src/TensorFlowNET.Core/Variables/_ReuseMode.cs index e63e51f7..9344e824 100644 --- a/src/TensorFlowNET.Core/Variables/_ReuseMode.cs +++ b/src/TensorFlowNET.Core/Variables/_ReuseMode.cs @@ -5,6 +5,7 @@ /// public enum _ReuseMode { + NOT_REUSE = 0, // Indicates that variables are to be fetched if they already exist or // otherwise created. AUTO_REUSE = 1