| @@ -65,7 +65,9 @@ namespace Tensorflow.Layers | |||
| variable_scope scope_context_manager = null; | |||
| if (built) | |||
| { | |||
| scope_context_manager = tf.variable_scope(_scope, | |||
| reuse: true, | |||
| auxiliary_name_scope: false); | |||
| } | |||
| else | |||
| { | |||
| @@ -181,7 +183,7 @@ namespace Tensorflow.Layers | |||
| return _current_scope.original_name_scope; | |||
| } | |||
| private void _set_scope(VariableScope scope = null) | |||
| protected void _set_scope(VariableScope scope = null) | |||
| { | |||
| if (_scope == null) | |||
| { | |||
| @@ -14,12 +14,17 @@ namespace Tensorflow | |||
| /// Basic LSTM recurrent network cell. | |||
| /// The implementation is based on: http://arxiv.org/abs/1409.2329. | |||
| /// </summary> | |||
| public class BasicLSTMCell : LayerRnnCell | |||
| public class BasicLstmCell : LayerRnnCell | |||
| { | |||
| int _num_units; | |||
| float _forget_bias; | |||
| bool _state_is_tuple; | |||
| IActivation _activation; | |||
| LSTMStateTuple _state; | |||
| VariableV1 _kernel; | |||
| VariableV1 _bias; | |||
| string _WEIGHTS_VARIABLE_NAME = "kernel"; | |||
| string _BIAS_VARIABLE_NAME = "bias"; | |||
| /// <summary> | |||
| /// Initialize the basic LSTM cell. | |||
| @@ -31,7 +36,7 @@ namespace Tensorflow | |||
| /// <param name="reuse"></param> | |||
| /// <param name="name"></param> | |||
| /// <param name="dtype"></param> | |||
| public BasicLSTMCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true, | |||
| public BasicLstmCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true, | |||
| IActivation activation = null, bool? reuse = null, string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype) | |||
| { | |||
| @@ -44,13 +49,123 @@ namespace Tensorflow | |||
| _activation = tf.nn.tanh(); | |||
| } | |||
| public LSTMStateTuple state_size | |||
| /// <summary> | |||
| /// Creates the kernel and bias variables of the cell and marks the layer as built. | |||
| /// </summary> | |||
| /// <param name="input_shape">Shape of the inputs; its last dimension is used as the input depth.</param> | |||
| protected override void build(TensorShape input_shape) | |||
| { | |||
| // Kernel rows: one per input feature plus one per hidden unit. | |||
| var kernel_rows = input_shape.dims.Last() + _num_units; | |||
| // Four column groups of _num_units each, shared by kernel and bias. | |||
| var gate_columns = 4 * _num_units; | |||
| _kernel = add_weight(_WEIGHTS_VARIABLE_NAME, | |||
| shape: new[] { kernel_rows, gate_columns }); | |||
| _bias = add_weight(_BIAS_VARIABLE_NAME, | |||
| shape: new[] { gate_columns }, | |||
| initializer: tf.zeros_initializer); | |||
| built = true; | |||
| } | |||
| /// <summary> | |||
| /// Stores <paramref name="state"/> in the _state field (read by the call override of this class) | |||
| /// and then invokes the base __call__ with <paramref name="inputs"/>. | |||
| /// </summary> | |||
| /// <param name="inputs">Input tensor handed to the base __call__.</param> | |||
| /// <param name="state">State tuple kept in _state for the duration of the call.</param> | |||
| /// <returns>The tensors returned by the base __call__.</returns> | |||
| public Tensor[] __call__(Tensor inputs, LSTMStateTuple state) | |||
| { | |||
| this._state = state; | |||
| var outputs = base.__call__(inputs); | |||
| return outputs; | |||
| } | |||
| /// <summary> | |||
| /// Long short-term memory cell (LSTM): computes one step from the inputs and the state held in the _state field. | |||
| /// </summary> | |||
| /// <param name="inputs">Input tensor; concatenated with h along axis 1 before the kernel matmul.</param> | |||
| /// <param name="training">Not read by this method.</param> | |||
| /// <param name="state">Not read by this method; c and h are taken from the _state field instead.</param> | |||
| /// <returns>An array { new_c, new_h } when _state_is_tuple is set; the non-tuple path throws NotImplementedException before reaching its return.</returns> | |||
| protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) | |||
| { | |||
| var one = constant_op.constant(1, dtype: dtypes.int32); | |||
| // Parameters of gates are concatenated into one multiply for efficiency. | |||
| Tensor c = null; | |||
| Tensor h = null; | |||
| if(_state_is_tuple) | |||
| (c, h) = ((Tensor)_state.c, (Tensor)_state.h); | |||
| else | |||
| { | |||
| // array_ops.split(value: state, num_or_size_splits: 2, axis: one); | |||
| throw new NotImplementedException("BasicLstmCell call"); | |||
| } | |||
| var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel as RefVariable); | |||
| gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); | |||
| // i = input_gate, j = new_input, f = forget_gate, o = output_gate | |||
| var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one); | |||
| var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); | |||
| var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); | |||
| // Note that using `add` and `multiply` instead of `+` and `*` gives a | |||
| // performance improvement. So using those at the cost of readability. | |||
| var new_c = gen_math_ops.add( | |||
| math_ops.multiply(c, math_ops.sigmoid(gen_math_ops.add(f, forget_bias_tensor))), | |||
| math_ops.multiply(math_ops.sigmoid(i), _activation.Activate(j))); | |||
| var new_h = math_ops.multiply(_activation.Activate(new_c), math_ops.sigmoid(o)); | |||
| if (_state_is_tuple) | |||
| return new[] { new_c, new_h }; | |||
| else | |||
| return new[] { array_ops.concat(new[] { new_c, new_h }, 1) }; | |||
| } | |||
| /// <summary> | |||
| /// Returns the zero state for the given batch size and dtype. | |||
| /// </summary> | |||
| /// <param name="inputs">Must be null; a non-null value is not supported yet.</param> | |||
| /// <param name="batch_size">Batch size tensor forwarded to zero_state.</param> | |||
| /// <param name="dtype">Data type forwarded to zero_state.</param> | |||
| /// <returns>The value produced by zero_state.</returns> | |||
| public override object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| var has_inputs = inputs != null; | |||
| if (has_inputs) | |||
| { | |||
| throw new NotImplementedException("get_initial_state input is not null"); | |||
| } | |||
| var initial_state = zero_state(batch_size, dtype); | |||
| return initial_state; | |||
| } | |||
| /// <summary> | |||
| /// Return zero-filled state tensor(s), built inside a name scope called "{runtime type name}ZeroState". | |||
| /// </summary> | |||
| /// <param name="batch_size">Batch size tensor; passed as the name scope's values and forwarded to _zero_state_tensors.</param> | |||
| /// <param name="dtype">Data type forwarded to _zero_state_tensors.</param> | |||
| /// <returns>The LSTMStateTuple produced by _zero_state_tensors for this cell's state_size.</returns> | |||
| private LSTMStateTuple zero_state(Tensor batch_size, TF_DataType dtype) | |||
| { | |||
| LSTMStateTuple output = null; | |||
| tf_with(ops.name_scope($"{GetType().Name}ZeroState", values: new { batch_size }), delegate | |||
| { | |||
| output = _zero_state_tensors(state_size, batch_size, dtype); | |||
| }); | |||
| return output; | |||
| } | |||
| /// <summary> | |||
| /// Builds one zero tensor per entry of a tuple-shaped state size and wraps the first two in an LSTMStateTuple. | |||
| /// </summary> | |||
| /// <param name="state_size">Expected to be an LSTMStateTuple whose flattened entries are ints; anything else is not implemented.</param> | |||
| /// <param name="batch_size">Batch size tensor combined with each entry via rnn_cell_impl._concat.</param> | |||
| /// <param name="dtype">Data type of the zero tensors.</param> | |||
| /// <returns>A tuple holding the zero tensors for the first and second entries.</returns> | |||
| private LSTMStateTuple _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype) | |||
| { | |||
| if (!(state_size is LSTMStateTuple state_size_tuple)) | |||
| throw new NotImplementedException("_zero_state_tensors"); | |||
| var parts = state_size_tuple.Flatten(); | |||
| var zeros = new Tensor[parts.Length]; | |||
| for (var idx = 0; idx < parts.Length; idx++) | |||
| { | |||
| var units = (int)parts[idx]; | |||
| // Shape used to create the tensor. | |||
| var dynamic_shape = rnn_cell_impl._concat(batch_size, units); | |||
| var zero = array_ops.zeros(dynamic_shape, dtype: dtype); | |||
| // Shape requested with @static: true, applied afterwards through set_shape. | |||
| var static_shape = rnn_cell_impl._concat(batch_size, units, @static: true); | |||
| zero.set_shape(static_shape); | |||
| zeros[idx] = zero; | |||
| } | |||
| return new LSTMStateTuple(zeros[0], zeros[1]); | |||
| } | |||
| public override object state_size | |||
| { | |||
| get | |||
| { | |||
| return _state_is_tuple ? | |||
| new LSTMStateTuple(_num_units, _num_units) : | |||
| (LSTMStateTuple)(2 * _num_units); | |||
| if (_state_is_tuple) | |||
| return new LSTMStateTuple(_num_units, _num_units); | |||
| else | |||
| return 2 * _num_units; | |||
| } | |||
| } | |||
| } | |||
| @@ -26,7 +26,7 @@ namespace Tensorflow | |||
| int _num_units; | |||
| Func<Tensor, string, Tensor> _activation; | |||
| public override LSTMStateTuple state_size => _num_units; | |||
| public override object state_size => _num_units; | |||
| public override int output_size => _num_units; | |||
| public VariableV1 _kernel; | |||
| string _WEIGHTS_VARIABLE_NAME = "kernel"; | |||
| @@ -12,15 +12,10 @@ namespace Tensorflow.Operations | |||
| /// | |||
| /// Only used when `state_is_tuple=True`. | |||
| /// </summary> | |||
| public class LSTMStateTuple | |||
| public class LSTMStateTuple : ICanBeFlattened | |||
| { | |||
| int c; | |||
| int h; | |||
| public LSTMStateTuple(int c) | |||
| { | |||
| this.c = c; | |||
| } | |||
| public object c; | |||
| public object h; | |||
| public LSTMStateTuple(int c, int h) | |||
| { | |||
| @@ -28,14 +23,13 @@ namespace Tensorflow.Operations | |||
| this.h = h; | |||
| } | |||
| public static implicit operator int(LSTMStateTuple tuple) | |||
| public LSTMStateTuple(Tensor c, Tensor h) | |||
| { | |||
| return tuple.c; | |||
| this.c = c; | |||
| this.h = h; | |||
| } | |||
| public static implicit operator LSTMStateTuple(int c) | |||
| { | |||
| return new LSTMStateTuple(c); | |||
| } | |||
| /// <summary> | |||
| /// Returns the two members of the tuple as a flat array, c first and h second. | |||
| /// </summary> | |||
| public object[] Flatten() | |||
| { | |||
| return new object[] { c, h }; | |||
| } | |||
| } | |||
| } | |||
| @@ -49,7 +49,7 @@ namespace Tensorflow | |||
| /// difference between TF and Keras RNN cell. | |||
| /// </summary> | |||
| protected bool _is_tf_rnn_cell = false; | |||
| public virtual LSTMStateTuple state_size { get; } | |||
| public virtual object state_size { get; } | |||
| public virtual int output_size { get; } | |||
| @@ -64,7 +64,7 @@ namespace Tensorflow | |||
| _is_tf_rnn_cell = true; | |||
| } | |||
| public virtual Tensor get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| public virtual object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| if (inputs != null) | |||
| throw new NotImplementedException("get_initial_state input is not null"); | |||
| @@ -78,11 +78,10 @@ namespace Tensorflow | |||
| /// <param name="batch_size"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <returns></returns> | |||
| public Tensor zero_state(Tensor batch_size, TF_DataType dtype) | |||
| private Tensor zero_state(Tensor batch_size, TF_DataType dtype) | |||
| { | |||
| Tensor output = null; | |||
| var state_size = this.state_size; | |||
| tf_with(ops.name_scope($"{this.GetType().Name}ZeroState", values: new { batch_size }), delegate | |||
| tf_with(ops.name_scope($"{GetType().Name}ZeroState", values: new { batch_size }), delegate | |||
| { | |||
| output = _zero_state_tensors(state_size, batch_size, dtype); | |||
| }); | |||
| @@ -90,20 +89,25 @@ namespace Tensorflow | |||
| return output; | |||
| } | |||
| private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype) | |||
| private Tensor _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype) | |||
| { | |||
| var output = nest.map_structure(s => | |||
| if(state_size is int state_size_int) | |||
| { | |||
| var c = rnn_cell_impl._concat(batch_size, s); | |||
| var size = array_ops.zeros(c, dtype: dtype); | |||
| var output = nest.map_structure(s => | |||
| { | |||
| var c = rnn_cell_impl._concat(batch_size, s); | |||
| var size = array_ops.zeros(c, dtype: dtype); | |||
| var c_static = rnn_cell_impl._concat(batch_size, s, @static: true); | |||
| size.set_shape(c_static); | |||
| var c_static = rnn_cell_impl._concat(batch_size, s, @static: true); | |||
| size.set_shape(c_static); | |||
| return size; | |||
| }, state_size); | |||
| return size; | |||
| }, state_size_int); | |||
| return output; | |||
| return output; | |||
| } | |||
| throw new NotImplementedException("_zero_state_tensors"); | |||
| } | |||
| } | |||
| } | |||
| @@ -29,8 +29,8 @@ namespace Tensorflow.Operations | |||
| /// <summary> | |||
| /// Creates a bidirectional recurrent neural network. | |||
| /// </summary> | |||
| public static void static_bidirectional_rnn(BasicLSTMCell cell_fw, | |||
| BasicLSTMCell cell_bw, | |||
| public static (Tensor[], LSTMStateTuple, LSTMStateTuple) static_bidirectional_rnn(BasicLstmCell cell_fw, | |||
| BasicLstmCell cell_bw, | |||
| Tensor[] inputs, | |||
| Tensor initial_state_fw = null, | |||
| Tensor initial_state_bw = null, | |||
| @@ -41,12 +41,17 @@ namespace Tensorflow.Operations | |||
| if (inputs == null || inputs.Length == 0) | |||
| throw new ValueError("inputs must not be empty"); | |||
| Tensor[] output_fw = null; | |||
| Tensor[] output_bw = null; | |||
| LSTMStateTuple output_state_fw = null; | |||
| LSTMStateTuple output_state_bw = null; | |||
| tf_with(tf.variable_scope(scope ?? "bidirectional_rnn"), delegate | |||
| { | |||
| // Forward direction | |||
| tf_with(tf.variable_scope("fw"), fw_scope => | |||
| { | |||
| static_rnn( | |||
| (output_fw, output_state_fw) = static_rnn( | |||
| cell_fw, | |||
| inputs, | |||
| initial_state_fw, | |||
| @@ -54,16 +59,48 @@ namespace Tensorflow.Operations | |||
| sequence_length, | |||
| scope: fw_scope); | |||
| }); | |||
| // backward direction | |||
| tf_with(tf.variable_scope("bw"), bw_scope => | |||
| { | |||
| var reversed_inputs = _reverse_seq(inputs, sequence_length); | |||
| (output_bw, output_state_bw) = static_rnn( | |||
| cell_bw, | |||
| reversed_inputs, | |||
| initial_state_bw, | |||
| dtype, | |||
| sequence_length, | |||
| scope: bw_scope); | |||
| }); | |||
| }); | |||
| output_bw = _reverse_seq(output_bw, sequence_length); | |||
| var flat_outputs = zip(output_fw, output_bw) | |||
| .Select(x => array_ops.concat(new[] { x.Item1, x.Item2 }, 1)) | |||
| .ToArray(); | |||
| return (flat_outputs, output_state_fw, output_state_bw); | |||
| } | |||
| public static void static_rnn(BasicLSTMCell cell, | |||
| /// <summary> | |||
| /// Reverses the order of the tensors in a sequence. | |||
| /// </summary> | |||
| /// <param name="input_seq">Sequence of tensors to reverse; the array itself is not modified.</param> | |||
| /// <param name="lengths">Must be null; length-aware reversal is not implemented.</param> | |||
| /// <returns>A new array with the elements of input_seq in reverse order.</returns> | |||
| private static Tensor[] _reverse_seq(Tensor[] input_seq, Tensor lengths) | |||
| { | |||
| var no_lengths = lengths == null; | |||
| if (!no_lengths) | |||
| { | |||
| throw new NotImplementedException("_reverse_seq"); | |||
| } | |||
| var reversed = input_seq.Reverse(); | |||
| return reversed.ToArray(); | |||
| } | |||
| public static (Tensor[], LSTMStateTuple) static_rnn(BasicLstmCell cell, | |||
| Tensor[] inputs, | |||
| Tensor initial_state, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| Tensor sequence_length = null, | |||
| VariableScope scope = null) | |||
| { | |||
| List<Tensor> outputs = new List<Tensor>(); | |||
| object state = null; | |||
| // Create a new scope in which the caching device is either | |||
| // determined by the parent scope, or is set to place the cached | |||
| // Variable using the same placement as for the rest of the RNN. | |||
| @@ -73,12 +110,12 @@ namespace Tensorflow.Operations | |||
| throw new NotImplementedException("static_rnn"); | |||
| }); | |||
| else | |||
| tf_with(tf.variable_scope(scope), varscope => | |||
| tf_with(tf.variable_scope(scope), scope1 => | |||
| { | |||
| Dimension fixed_batch_size = null; | |||
| Dimension batch_size = null; | |||
| Tensor batch_size_tensor = null; | |||
| VariableScope varscope = scope1; | |||
| // Obtain the first sequence of the input | |||
| var first_input = inputs[0]; | |||
| if (first_input.TensorShape.rank != 1) | |||
| @@ -108,14 +145,31 @@ namespace Tensorflow.Operations | |||
| else | |||
| batch_size_tensor = array_ops.shape(first_input)[0]; | |||
| Tensor state = null; | |||
| if (initial_state != null) | |||
| state = initial_state; | |||
| else | |||
| { | |||
| cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype); | |||
| state = cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype); | |||
| } | |||
| Tensor output = null; | |||
| if (state is LSTMStateTuple state_tuple) | |||
| { | |||
| foreach (var (time, input_) in enumerate(inputs)) | |||
| { | |||
| if (time > 0) | |||
| varscope.reuse_variables(); | |||
| if (sequence_length != null) | |||
| throw new NotImplementedException("static_rnn"); | |||
| var results = cell.__call__(input_, state_tuple); | |||
| (output, state_tuple) = (results[1], new LSTMStateTuple(results[0], results[1])); | |||
| outputs.Add(output); | |||
| } | |||
| } | |||
| }); | |||
| return (outputs.ToArray(), state as LSTMStateTuple); | |||
| } | |||
| public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, | |||
| @@ -145,7 +199,7 @@ namespace Tensorflow.Operations | |||
| if (initial_state != null) | |||
| state = initial_state; | |||
| else | |||
| state = cell.get_initial_state(batch_size: batch_size, dtype: dtype); | |||
| state = cell.get_initial_state(batch_size: batch_size, dtype: dtype) as Tensor; | |||
| var inputs = nest.pack_sequence_as(structure: inputs_tensor, flat_sequence: flat_input); | |||
| @@ -604,6 +604,11 @@ namespace Tensorflow | |||
| return gen_array_ops.concat_v2(values, axis, name: name); | |||
| } | |||
| /// <summary> | |||
| /// Concatenates tensors along an axis that is itself given as a tensor; forwards to gen_array_ops.concat_v2. | |||
| /// </summary> | |||
| /// <param name="values">Tensors to concatenate.</param> | |||
| /// <param name="axis">Tensor holding the axis to concatenate along.</param> | |||
| /// <param name="name">Name for the operation.</param> | |||
| public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat") | |||
| => gen_array_ops.concat_v2(values, axis, name: name); | |||
| public static Tensor concat(object[] values, int axis, string name = "concat") | |||
| { | |||
| return gen_array_ops.concat_v2(values, axis, name: name); | |||
| @@ -629,6 +634,16 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
| /// <summary> | |||
| /// Splits a tensor into a fixed number of pieces along an axis given as a tensor; forwards to gen_array_ops.split. | |||
| /// </summary> | |||
| /// <param name="value">Tensor to split.</param> | |||
| /// <param name="num_or_size_splits">Number of pieces, passed on as num_split.</param> | |||
| /// <param name="axis">Tensor holding the axis to split along.</param> | |||
| /// <param name="name">Name for the operation.</param> | |||
| /// <returns>The tensors returned by gen_array_ops.split.</returns> | |||
| public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis, | |||
| string name = "split") | |||
| { | |||
| // The split count goes to the op as a plain int, so it is not converted to a tensor here. | |||
| return gen_array_ops.split(axis: axis, | |||
| num_split: num_or_size_splits, | |||
| value: value, | |||
| name: name); | |||
| } | |||
| public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | |||
| => gen_array_ops.slice(input, begin, size, name: name); | |||
| @@ -47,7 +47,7 @@ namespace Tensorflow | |||
| /// <param name="axis"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor concat_v2<T>(T[] values, int axis, string name = null) | |||
| public static Tensor concat_v2<T, Ta>(T[] values, Ta axis, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis }); | |||
| @@ -5,7 +5,7 @@ | |||
| <AssemblyName>TensorFlow.NET</AssemblyName> | |||
| <RootNamespace>Tensorflow</RootNamespace> | |||
| <TargetTensorFlow>1.14.1</TargetTensorFlow> | |||
| <Version>0.12.1</Version> | |||
| <Version>0.13.0</Version> | |||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | |||
| <Company>SciSharp STACK</Company> | |||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
| @@ -18,14 +18,16 @@ | |||
| <Description>Google's TensorFlow full binding in .NET Standard. | |||
| Building, training and inferring deep learning models. | |||
| https://tensorflownet.readthedocs.io</Description> | |||
| <AssemblyVersion>0.12.1.0</AssemblyVersion> | |||
| <PackageReleaseNotes>Changes since v0.11.0: | |||
| <AssemblyVersion>0.13.0.0</AssemblyVersion> | |||
| <PackageReleaseNotes>Changes since v0.12.0: | |||
| 1: Add ICanBeFlattened for nest.flatten2. | |||
| 2: Complete the WhileContext. | |||
| 3: Add tf.nn.rnn_cell.BasicRNNCell and tf.nn.dynamic_rnn. | |||
| 4: Add EstimatorSpec.</PackageReleaseNotes> | |||
| 4: Add EstimatorSpec. | |||
| 5: Add rnn.static_rnn. | |||
| 6: Add array_grad._SplitGrad().</PackageReleaseNotes> | |||
| <LangVersion>7.3</LangVersion> | |||
| <FileVersion>0.12.1.0</FileVersion> | |||
| <FileVersion>0.13.0.0</FileVersion> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
| <SignAssembly>true</SignAssembly> | |||
| @@ -7,20 +7,6 @@ namespace Tensorflow | |||
| { | |||
| public partial class Tensor | |||
| { | |||
| /// <summary> | |||
| /// Issue unresolved, will cause name_scope problem. | |||
| /// </summary> | |||
| /// <param name="scalar"></param> | |||
| /*public static implicit operator Tensor(double scalar) | |||
| { | |||
| return constant_op.constant(scalar); | |||
| }*/ | |||
| /*public static implicit operator Tensor(int scalar) | |||
| { | |||
| return constant_op.constant(scalar); | |||
| }*/ | |||
| public static implicit operator IntPtr(Tensor tensor) | |||
| { | |||
| if (tensor._handle == IntPtr.Zero) | |||
| @@ -526,14 +526,6 @@ namespace Tensorflow.Util | |||
| return pack_sequence_as(structure, mapped_flat_structure) as Tensor; | |||
| } | |||
| public static Tensor map_structure2<T>(Func<T, Tensor> func, T structure) | |||
| { | |||
| var flat_structure = flatten(structure); | |||
| var mapped_flat_structure = flat_structure.Select(func).ToList(); | |||
| return pack_sequence_as(structure, mapped_flat_structure) as Tensor; | |||
| } | |||
| /// <summary> | |||
| /// Same as map_structure, but with only one structure (no combining of multiple structures) | |||
| /// </summary> | |||
| @@ -74,5 +74,10 @@ namespace Tensorflow | |||
| aggregation: aggregation) as RefVariable; | |||
| }); | |||
| } | |||
| /// <summary> | |||
| /// Switches the reuse mode of this scope to _ReuseMode.AUTO_REUSE. | |||
| /// </summary> | |||
| public void reuse_variables() | |||
| => _reuse = _ReuseMode.AUTO_REUSE; | |||
| } | |||
| } | |||
| @@ -5,6 +5,7 @@ | |||
| /// </summary> | |||
| public enum _ReuseMode | |||
| { | |||
| NOT_REUSE = 0, | |||
| // Indicates that variables are to be fetched if they already exist or | |||
| // otherwise created. | |||
| AUTO_REUSE = 1 | |||