| @@ -346,4 +346,5 @@ Get started with the implementation: | |||
| } | |||
| ``` | |||
|  | |||
|  | |||
| @@ -16,9 +16,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Operations.Activation; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Python; | |||
| namespace Tensorflow | |||
| @@ -68,6 +70,33 @@ namespace Tensorflow | |||
| return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name); | |||
| } | |||
| /// <summary> | |||
| /// Creates a recurrent neural network specified by RNNCell `cell`. | |||
| /// </summary> | |||
| /// <param name="cell">An instance of RNNCell.</param> | |||
| /// <param name="inputs">The RNN inputs.</param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="swap_memory"></param> | |||
| /// <param name="time_major"></param> | |||
| /// <returns>A pair (outputs, state)</returns> | |||
| public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs, TF_DataType dtype = TF_DataType.DtInvalid, | |||
| bool swap_memory = false, bool time_major = false) | |||
| { | |||
| with(variable_scope("rnn"), scope => | |||
| { | |||
| VariableScope varscope = scope; | |||
| var flat_input = nest.flatten(inputs); | |||
| if (!time_major) | |||
| { | |||
| flat_input = flat_input.Select(x => ops.convert_to_tensor(x)).ToList(); | |||
| //flat_input = flat_input.Select(x => _transpose_batch_time(x)).ToList(); | |||
| } | |||
| }); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| public static (Tensor, Tensor) moments(Tensor x, | |||
| int[] axes, | |||
| string name = null, | |||
| @@ -1,10 +1,48 @@ | |||
| using System; | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Operations.Activation; | |||
| namespace Tensorflow | |||
| { | |||
| public class BasicRNNCell | |||
| public class BasicRNNCell : LayerRNNCell | |||
| { | |||
| int _num_units; | |||
| Func<Tensor, string, Tensor> _activation; | |||
| public BasicRNNCell(int num_units, | |||
| Func<Tensor, string, Tensor> activation = null, | |||
| bool? reuse = null, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, | |||
| name: name, | |||
| dtype: dtype) | |||
| { | |||
| // Inputs must be 2-dimensional. | |||
| input_spec = new InputSpec(ndim: 2); | |||
| _num_units = num_units; | |||
| if (activation == null) | |||
| _activation = math_ops.tanh; | |||
| else | |||
| _activation = activation; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,33 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class LayerRNNCell : RNNCell | |||
| { | |||
| public LayerRNNCell(bool? _reuse = null, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: _reuse, | |||
| name: name, | |||
| dtype: dtype) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,63 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Abstract object representing an RNN cell. | |||
| /// | |||
| /// Every `RNNCell` must have the properties below and implement `call` with | |||
| /// the signature `(output, next_state) = call(input, state)`. The optional | |||
| /// third input argument, `scope`, is allowed for backwards compatibility | |||
| /// purposes; but should be left off for new subclasses. | |||
| /// | |||
| /// This definition of cell differs from the definition used in the literature. | |||
| /// In the literature, 'cell' refers to an object with a single scalar output. | |||
| /// This definition refers to a horizontal array of such units. | |||
| /// | |||
| /// An RNN cell, in the most abstract setting, is anything that has | |||
| /// a state and performs some operation that takes a matrix of inputs. | |||
| /// This operation results in an output matrix with `self.output_size` columns. | |||
| /// If `self.state_size` is an integer, this operation also results in a new | |||
| /// state matrix with `self.state_size` columns. If `self.state_size` is a | |||
| /// (possibly nested tuple of) TensorShape object(s), then it should return a | |||
| /// matching structure of Tensors having shape `[batch_size].concatenate(s)` | |||
| /// for each `s` in `self.batch_size`. | |||
| /// </summary> | |||
| public abstract class RNNCell : Layers.Layer | |||
| { | |||
| /// <summary> | |||
| /// Attribute that indicates whether the cell is a TF RNN cell, due the slight | |||
| /// difference between TF and Keras RNN cell. | |||
| /// </summary> | |||
| protected bool _is_tf_rnn_cell = false; | |||
| public RNNCell(bool trainable = true, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| bool? _reuse = null) : base(trainable: trainable, | |||
| name: name, | |||
| dtype: dtype, | |||
| _reuse: _reuse) | |||
| { | |||
| _is_tf_rnn_cell = true; | |||
| } | |||
| } | |||
| } | |||
| @@ -551,6 +551,9 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
| public static Tensor tanh(Tensor x, string name = null) | |||
| => gen_math_ops.tanh(x, name); | |||
| public static Tensor truediv(Tensor x, Tensor y, string name = null) | |||
| => _truediv_python3(x, y, name); | |||
| @@ -7,8 +7,6 @@ namespace Tensorflow.Operations | |||
| public class rnn_cell_impl | |||
| { | |||
| public BasicRNNCell BasicRNNCell(int num_units) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| => new BasicRNNCell(num_units); | |||
| } | |||
| } | |||
| @@ -214,14 +214,14 @@ namespace Tensorflow.Util | |||
| //# See the swig file (util.i) for documentation. | |||
| //flatten = _pywrap_tensorflow.Flatten | |||
| public static List<object> flatten(object structure) | |||
| public static List<T> flatten<T>(T structure) | |||
| { | |||
| var list = new List<object>(); | |||
| var list = new List<T>(); | |||
| _flatten_recursive(structure, list); | |||
| return list; | |||
| } | |||
| private static void _flatten_recursive(object obj, List<object> list) | |||
| private static void _flatten_recursive<T>(T obj, List<T> list) | |||
| { | |||
| if (obj is string) | |||
| { | |||
| @@ -232,7 +232,7 @@ namespace Tensorflow.Util | |||
| { | |||
| var dict = obj as IDictionary; | |||
| foreach (var key in _sorted(dict)) | |||
| _flatten_recursive(dict[key], list); | |||
| _flatten_recursive((T)dict[key], list); | |||
| return; | |||
| } | |||
| if (obj is NDArray) | |||
| @@ -244,7 +244,7 @@ namespace Tensorflow.Util | |||
| { | |||
| var structure = obj as IEnumerable; | |||
| foreach (var child in structure) | |||
| _flatten_recursive(child, list); | |||
| _flatten_recursive((T)child, list); | |||
| return; | |||
| } | |||
| list.Add(obj); | |||
| @@ -25,14 +25,12 @@ using static Tensorflow.Python; | |||
| namespace TensorFlowNET.Examples.ImageProcess | |||
| { | |||
| /// <summary> | |||
| /// Convolutional Neural Network classifier for Hand Written Digits | |||
| /// CNN architecture with two convolutional layers, followed by two fully-connected layers at the end. | |||
| /// Use Stochastic Gradient Descent (SGD) optimizer. | |||
| /// http://www.easy-tensorflow.com/tf-tutorials/convolutional-neural-nets-cnns/cnn1 | |||
| /// Recurrent Neural Network for handwritten digits MNIST. | |||
| /// https://medium.com/machine-learning-algorithms/mnist-using-recurrent-neural-network-2d070a5915a2 | |||
| /// </summary> | |||
| public class DigitRecognitionRNN : IExample | |||
| { | |||
| public bool Enabled { get; set; } = false; | |||
| public bool Enabled { get; set; } = true; | |||
| public bool IsImportingGraph { get; set; } = false; | |||
| public string Name => "MNIST RNN"; | |||
| @@ -84,6 +82,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||
| var X = tf.placeholder(tf.float32, new[] { -1, n_steps, n_inputs }); | |||
| var y = tf.placeholder(tf.int32, new[] { -1 }); | |||
| var cell = tf.nn.rnn_cell.BasicRNNCell(num_units: n_neurons); | |||
| var (output, state) = tf.nn.dynamic_rnn(cell, X, dtype: tf.float32); | |||
| return graph; | |||
| } | |||
| @@ -154,6 +153,7 @@ namespace TensorFlowNET.Examples.ImageProcess | |||
| print("Size of:"); | |||
| print($"- Training-set:\t\t{len(mnist.train.data)}"); | |||
| print($"- Validation-set:\t{len(mnist.validation.data)}"); | |||
| print($"- Test-set:\t\t{len(mnist.test.data)}"); | |||
| } | |||
| public Graph ImportGraph() => throw new NotImplementedException(); | |||
| @@ -78,8 +78,8 @@ namespace TensorFlowNET.UnitTest.nest_test | |||
| self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["y"], 0); | |||
| self.assertEqual(new List<object> { 5 }, nest.flatten(5)); | |||
| flat = nest.flatten(np.array(new[] { 5 })); | |||
| self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat); | |||
| var flat1 = nest.flatten(np.array(new[] { 5 })); | |||
| self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat1); | |||
| self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" })); | |||
| self.assertEqual(np.array(new[] { 5 }), | |||