| @@ -27,7 +27,8 @@ namespace Tensorflow | |||||
| public IInitializer zeros_initializer => new Zeros(); | public IInitializer zeros_initializer => new Zeros(); | ||||
| public IInitializer ones_initializer => new Ones(); | public IInitializer ones_initializer => new Ones(); | ||||
| public IInitializer glorot_uniform_initializer => new GlorotUniform(); | public IInitializer glorot_uniform_initializer => new GlorotUniform(); | ||||
| public IInitializer uniform_initializer => new RandomUniform(); | |||||
| public IInitializer random_uniform_initializer => new RandomUniform(); | |||||
| public IInitializer orthogonal_initializer => new Orthogonal(); | |||||
| public variable_scope variable_scope(string name, | public variable_scope variable_scope(string name, | ||||
| string default_name = null, | string default_name = null, | ||||
| @@ -20,7 +20,9 @@ namespace Tensorflow.Keras | |||||
| return results[0]; | return results[0]; | ||||
| } | } | ||||
| throw new NotImplementedException(""); | |||||
| var _op = tf.OpDefLib._apply_op_helper("Relu", name: name, args: new { features }); | |||||
| return _op.output; | |||||
| }; | }; | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,26 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Operations; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras | |||||
| { | |||||
| public partial class Activations | |||||
| { | |||||
| public Activation Sigmoid = (features, name) => | |||||
| { | |||||
| if (tf.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "Sigmoid", name, | |||||
| null, | |||||
| features); | |||||
| return results[0]; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| }; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,26 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Operations; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras | |||||
| { | |||||
| public partial class Activations | |||||
| { | |||||
| public Activation Tanh = (features, name) => | |||||
| { | |||||
| if (tf.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "Tanh", name, | |||||
| null, | |||||
| features); | |||||
| return results[0]; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| }; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,15 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class EmbeddingArgs : LayerArgs | |||||
| { | |||||
| public int InputDim { get; set; } | |||||
| public int OutputDim { get; set; } | |||||
| public bool MaskZero { get; set; } | |||||
| public int InputLength { get; set; } = -1; | |||||
| public IInitializer EmbeddingsInitializer { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,26 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class LSTMArgs : RNNArgs | |||||
| { | |||||
| public int Units { get; set; } | |||||
| public Activation Activation { get; set; } | |||||
| public Activation RecurrentActivation { get; set; } | |||||
| public IInitializer KernelInitializer { get; set; } | |||||
| public IInitializer RecurrentInitializer { get; set; } | |||||
| public IInitializer BiasInitializer { get; set; } | |||||
| public bool UnitForgetBias { get; set; } | |||||
| public float Dropout { get; set; } | |||||
| public float RecurrentDropout { get; set; } | |||||
| public int Implementation { get; set; } | |||||
| public bool ReturnSequences { get; set; } | |||||
| public bool ReturnState { get; set; } | |||||
| public bool GoBackwards { get; set; } | |||||
| public bool Stateful { get; set; } | |||||
| public bool TimeMajor { get; set; } | |||||
| public bool Unroll { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class LSTMCellArgs : LayerArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class RNNArgs : LayerArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -26,17 +26,22 @@ namespace Tensorflow.Keras.Engine | |||||
| public int? ndim; | public int? ndim; | ||||
| public int? min_ndim; | public int? min_ndim; | ||||
| Dictionary<int, int> axes; | Dictionary<int, int> axes; | ||||
| TensorShape shape; | |||||
| public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, | public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| int? ndim = null, | int? ndim = null, | ||||
| int? min_ndim = null, | int? min_ndim = null, | ||||
| Dictionary<int, int> axes = null) | |||||
| Dictionary<int, int> axes = null, | |||||
| TensorShape shape = null) | |||||
| { | { | ||||
| this.ndim = ndim; | this.ndim = ndim; | ||||
| if (axes == null) | if (axes == null) | ||||
| axes = new Dictionary<int, int>(); | axes = new Dictionary<int, int>(); | ||||
| this.axes = axes; | this.axes = axes; | ||||
| this.min_ndim = min_ndim; | this.min_ndim = min_ndim; | ||||
| this.shape = shape; | |||||
| if (ndim == null && shape != null) | |||||
| this.ndim = shape.ndim; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
| using Tensorflow.Operations.Activation; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| @@ -100,5 +101,46 @@ namespace Tensorflow.Keras.Engine | |||||
| _layers.Add(layer); | _layers.Add(layer); | ||||
| return layer; | return layer; | ||||
| } | } | ||||
| protected Layer LSTM(int units, | |||||
| Activation activation = null, | |||||
| Activation recurrent_activation = null, | |||||
| bool use_bias = true, | |||||
| IInitializer kernel_initializer = null, | |||||
| IInitializer recurrent_initializer = null, | |||||
| IInitializer bias_initializer = null, | |||||
| bool unit_forget_bias = true, | |||||
| float dropout = 0f, | |||||
| float recurrent_dropout = 0f, | |||||
| int implementation = 2, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool time_major = false, | |||||
| bool unroll = false) | |||||
| { | |||||
| var layer = new LSTM(new LSTMArgs | |||||
| { | |||||
| Units = units, | |||||
| Activation = activation ?? tf.keras.activations.Tanh, | |||||
| RecurrentActivation = recurrent_activation ?? tf.keras.activations.Sigmoid, | |||||
| KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, | |||||
| RecurrentInitializer = recurrent_initializer ?? tf.orthogonal_initializer, | |||||
| BiasInitializer = bias_initializer ?? tf.zeros_initializer, | |||||
| Dropout = dropout, | |||||
| RecurrentDropout = recurrent_dropout, | |||||
| Implementation = implementation, | |||||
| ReturnSequences = return_sequences, | |||||
| ReturnState = return_state, | |||||
| GoBackwards = go_backwards, | |||||
| Stateful = stateful, | |||||
| TimeMajor = time_major, | |||||
| Unroll = unroll | |||||
| }); | |||||
| _layers.Add(layer); | |||||
| return layer; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -144,7 +144,9 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| // using var graph = tf.keras.backend.get_graph().as_default(); | // using var graph = tf.keras.backend.get_graph().as_default(); | ||||
| if (!inputs.IsEagerTensor) | |||||
| tf.Context.graph_mode(); | |||||
| tf_with(ops.name_scope(nameScope), scope => | tf_with(ops.name_scope(nameScope), scope => | ||||
| { | { | ||||
| if (!built) | if (!built) | ||||
| @@ -157,6 +159,8 @@ namespace Tensorflow.Keras.Engine | |||||
| _set_mask_metadata(inputs, outputs, null); | _set_mask_metadata(inputs, outputs, null); | ||||
| }); | }); | ||||
| tf.Context.eager_mode(); | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| @@ -17,7 +17,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Layers; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| @@ -56,6 +56,8 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| // Set metadata on outputs. | // Set metadata on outputs. | ||||
| var node_index = layer.InboundNodes.Count - 1; | |||||
| args.Outputs.KerasHistory.Add(layer); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -14,17 +14,22 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| public class Sequential : Model, ITensorFlowObject | |||||
| /// <summary> | |||||
| /// `Sequential` groups a linear stack of layers into a `tf.keras.Model`. | |||||
| /// `Sequential` provides training and inference features on this model. | |||||
| /// </summary> | |||||
| public class Sequential | |||||
| { | { | ||||
| #pragma warning disable CS0649 // Field 'Sequential._is_graph_network' is never assigned to, and will always have its default value false | |||||
| bool _is_graph_network; | bool _is_graph_network; | ||||
| #pragma warning restore CS0649 // Field 'Sequential._is_graph_network' is never assigned to, and will always have its default value false | |||||
| Tensor inputs; | |||||
| Tensor outputs; | Tensor outputs; | ||||
| bool computeOutputAndMaskJointly; | bool computeOutputAndMaskJointly; | ||||
| @@ -32,26 +37,24 @@ namespace Tensorflow.Keras.Engine | |||||
| TensorShape inferredInputShape; | TensorShape inferredInputShape; | ||||
| bool hasExplicitInputShape; | bool hasExplicitInputShape; | ||||
| TF_DataType inputDType; | TF_DataType inputDType; | ||||
| Layer[] layers; | |||||
| List<Layer> layers; | |||||
| public TensorShape output_shape => outputs.TensorShape; | |||||
| bool built = false; | |||||
| public Sequential(Layer[] layers = null, string name = null) | public Sequential(Layer[] layers = null, string name = null) | ||||
| : base(new ModelArgs { Name = name}) | |||||
| { | { | ||||
| this.layers = layers ?? new Layer[0]; | |||||
| SupportsMasking = true; | |||||
| this.layers = layers == null ? new List<Layer>() : layers.ToList(); | |||||
| // SupportsMasking = true; | |||||
| computeOutputAndMaskJointly = true; | computeOutputAndMaskJointly = true; | ||||
| autoTrackSubLayers = false; | autoTrackSubLayers = false; | ||||
| hasExplicitInputShape = false; | hasExplicitInputShape = false; | ||||
| _is_graph_network = false; | |||||
| } | } | ||||
| public void __enter__() | |||||
| public void add(Tensor tensor) | |||||
| { | { | ||||
| } | |||||
| public void add(Tensor layer) | |||||
| { | |||||
| var layer = tensor.KerasHistory[0]; | |||||
| add(layer); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -62,9 +65,9 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| built = false; | built = false; | ||||
| var set_inputs = false; | var set_inputs = false; | ||||
| if(layers.Length == 0) | |||||
| if (layers.Count == 0) | |||||
| { | { | ||||
| if(layer is InputLayer) | |||||
| if (layer is InputLayer) | |||||
| { | { | ||||
| set_inputs = true; | set_inputs = true; | ||||
| } | } | ||||
| @@ -93,31 +96,33 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| } | } | ||||
| else if (outputs != null) | |||||
| { | |||||
| outputs = layer.Apply(outputs); | |||||
| } | |||||
| if (set_inputs || _is_graph_network) | if (set_inputs || _is_graph_network) | ||||
| { | { | ||||
| _init_graph_network(inputs, outputs); | |||||
| } | } | ||||
| } | |||||
| public void __exit__() | |||||
| { | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| else | |||||
| { | |||||
| } | |||||
| } | } | ||||
| public void __init__() | |||||
| void _init_graph_network(Tensor inputs, Tensor outputs) | |||||
| { | { | ||||
| _is_graph_network = true; | |||||
| this.inputs = inputs; | |||||
| this.outputs = outputs; | |||||
| built = true; | |||||
| _map_graph_network(inputs, outputs); | |||||
| } | } | ||||
| public void __del__() | |||||
| void _map_graph_network(Tensor inputs, Tensor outputs) | |||||
| { | { | ||||
| layers.add(outputs.KerasHistory[0]); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -63,18 +63,9 @@ namespace Tensorflow | |||||
| var layer = new InputLayer(args); | var layer = new InputLayer(args); | ||||
| return layer.InboundNodes[0].Outputs[0]; | |||||
| return layer.InboundNodes[0].Outputs; | |||||
| } | } | ||||
| public static Embedding Embedding(int input_dim, | |||||
| int output_dim, | |||||
| IInitializer embeddings_initializer = null, | |||||
| bool mask_zero = false) | |||||
| => new Embedding(input_dim, | |||||
| output_dim, | |||||
| embeddings_initializer, | |||||
| mask_zero); | |||||
| public class LayersApi | public class LayersApi | ||||
| { | { | ||||
| public Layer Dense(int units, | public Layer Dense(int units, | ||||
| @@ -86,6 +77,30 @@ namespace Tensorflow | |||||
| Activation = activation ?? tf.keras.activations.Linear, | Activation = activation ?? tf.keras.activations.Linear, | ||||
| InputShape = input_shape | InputShape = input_shape | ||||
| }); | }); | ||||
| /// <summary> | |||||
| /// Turns positive integers (indexes) into dense vectors of fixed size. | |||||
| /// </summary> | |||||
| /// <param name="input_dim"></param> | |||||
| /// <param name="output_dim"></param> | |||||
| /// <param name="embeddings_initializer"></param> | |||||
| /// <param name="mask_zero"></param> | |||||
| /// <returns></returns> | |||||
| public Embedding Embedding(int input_dim, | |||||
| int output_dim, | |||||
| IInitializer embeddings_initializer = null, | |||||
| bool mask_zero = false, | |||||
| TensorShape input_shape = null, | |||||
| int input_length = -1) | |||||
| => new Embedding(new EmbeddingArgs | |||||
| { | |||||
| InputDim = input_dim, | |||||
| OutputDim = output_dim, | |||||
| MaskZero = mask_zero, | |||||
| InputShape = input_shape ?? input_length, | |||||
| InputLength = input_length, | |||||
| EmbeddingsInitializer = embeddings_initializer | |||||
| }); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -20,40 +20,37 @@ using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Turns positive integers (indexes) into dense vectors of fixed size. | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding | |||||
| /// </summary> | |||||
| public class Embedding : Layer | public class Embedding : Layer | ||||
| { | { | ||||
| private int input_dim; | |||||
| private int output_dim; | |||||
| private bool mask_zero; | |||||
| public IVariableV1 embeddings; | |||||
| public IInitializer embeddings_initializer; | |||||
| int input_length; | |||||
| EmbeddingArgs args; | |||||
| int input_dim => args.InputDim; | |||||
| int output_dim => args.OutputDim; | |||||
| bool mask_zero => args.MaskZero; | |||||
| IVariableV1 embeddings; | |||||
| IInitializer embeddings_initializer; | |||||
| public Embedding(int input_dim, int output_dim, | |||||
| IInitializer embeddings_initializer = null, | |||||
| bool mask_zero = false, | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| int[] input_shape = null, | |||||
| int input_length = -1) : | |||||
| base(new LayerArgs | |||||
| { | |||||
| DType = dtype, | |||||
| InputShape = input_shape ?? new[] { input_length } | |||||
| }) | |||||
| public Embedding(EmbeddingArgs args) | |||||
| : base(args) | |||||
| { | { | ||||
| this.input_dim = input_dim; | |||||
| this.output_dim = output_dim; | |||||
| this.embeddings_initializer = embeddings_initializer == null ? tf.uniform_initializer : embeddings_initializer; | |||||
| this.mask_zero = mask_zero; | |||||
| this.args = args; | |||||
| if(args.InputShape == null) | |||||
| args.InputShape = args.InputLength; | |||||
| embeddings_initializer = embeddings_initializer ?? tf.random_uniform_initializer; | |||||
| SupportsMasking = mask_zero; | SupportsMasking = mask_zero; | ||||
| this.input_length = input_length; | |||||
| } | } | ||||
| protected override void build(TensorShape input_shape) | protected override void build(TensorShape input_shape) | ||||
| { | { | ||||
| embeddings = add_weight(shape: new int[] { input_dim, output_dim }, | |||||
| tf.Context.eager_mode(); | |||||
| embeddings = add_weight(shape: (input_dim, output_dim), | |||||
| initializer: embeddings_initializer, | initializer: embeddings_initializer, | ||||
| name: "embeddings"); | name: "embeddings"); | ||||
| tf.Context.graph_mode(); | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| @@ -63,7 +60,7 @@ namespace Tensorflow.Keras.Layers | |||||
| if (dtype != tf.int32 && dtype != tf.int64) | if (dtype != tf.int32 && dtype != tf.int64) | ||||
| inputs = math_ops.cast(inputs, tf.int32); | inputs = math_ops.cast(inputs, tf.int32); | ||||
| var outputs = embedding_ops.embedding_lookup(embeddings, inputs[0]); | |||||
| var outputs = embedding_ops.embedding_lookup(embeddings, inputs); | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| } | } | ||||
| @@ -53,6 +53,11 @@ namespace Tensorflow.Keras.Layers | |||||
| args.Name = prefix + '_' + tf.keras.backend.get_uid(prefix); | args.Name = prefix + '_' + tf.keras.backend.get_uid(prefix); | ||||
| } | } | ||||
| if(args.DType == TF_DataType.DtInvalid) | |||||
| { | |||||
| args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype; | |||||
| } | |||||
| if (args.InputTensor == null) | if (args.InputTensor == null) | ||||
| { | { | ||||
| if(args.InputShape != null) | if(args.InputShape != null) | ||||
| @@ -72,7 +77,8 @@ namespace Tensorflow.Keras.Layers | |||||
| shape: BatchInputShape, | shape: BatchInputShape, | ||||
| dtype: DType, | dtype: DType, | ||||
| name: Name, | name: Name, | ||||
| sparse: args.Sparse); | |||||
| sparse: args.Sparse, | |||||
| ragged: args.Ragged); | |||||
| tf.Context.eager_mode(); | tf.Context.eager_mode(); | ||||
| isPlaceholder = true; | isPlaceholder = true; | ||||
| @@ -0,0 +1,37 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| /// <summary> | |||||
| /// Long Short-Term Memory layer - Hochreiter 1997. | |||||
| /// | |||||
| /// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) | |||||
| /// for details about the usage of RNN API. | |||||
| /// </summary> | |||||
| public class LSTM : RNN | |||||
| { | |||||
| LSTMArgs args; | |||||
| InputSpec[] state_spec; | |||||
| int units => args.Units; | |||||
| public LSTM(LSTMArgs args) : | |||||
| base(args) | |||||
| { | |||||
| this.args = args; | |||||
| state_spec = new[] { units, units } | |||||
| .Select(dim => new InputSpec(shape: (-1, dim))) | |||||
| .ToArray(); | |||||
| } | |||||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||||
| { | |||||
| return base.call(inputs, is_training, state); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,20 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Runtime.CompilerServices; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| public class LSTMCell : Layer | |||||
| { | |||||
| LSTMCellArgs args; | |||||
| public LSTMCell(LSTMCellArgs args) | |||||
| : base(args) | |||||
| { | |||||
| this.args = args; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,27 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| public class RNN : Layer | |||||
| { | |||||
| public RNN(RNNArgs args) | |||||
| : base(args) | |||||
| { | |||||
| } | |||||
| protected Tensor get_initial_state(Tensor inputs) | |||||
| { | |||||
| return _generate_zero_filled_state_for_cell(null, null); | |||||
| } | |||||
| Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Operations.Initializers | |||||
| { | |||||
| public class Orthogonal : IInitializer | |||||
| { | |||||
| public Tensor Apply(InitializerArgs args) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -29,7 +29,7 @@ namespace Tensorflow.Operations.Initializers | |||||
| #pragma warning restore CS0649 // Field 'RandomUniform.maxval' is never assigned to, and will always have its default value 0 | #pragma warning restore CS0649 // Field 'RandomUniform.maxval' is never assigned to, and will always have its default value 0 | ||||
| private TF_DataType dtype; | private TF_DataType dtype; | ||||
| public RandomUniform(TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| { | { | ||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| } | } | ||||
| @@ -37,7 +37,7 @@ namespace Tensorflow.Operations.Initializers | |||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| { | { | ||||
| if (args.DType == TF_DataType.DtInvalid) | if (args.DType == TF_DataType.DtInvalid) | ||||
| args.DType = this.dtype; | |||||
| args.DType = dtype; | |||||
| return random_ops.random_uniform(args.Shape, | return random_ops.random_uniform(args.Shape, | ||||
| minval: minval, | minval: minval, | ||||
| @@ -193,7 +193,7 @@ namespace Tensorflow | |||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| public static Tensor identity(Tensor input, string name = null) | public static Tensor identity(Tensor input, string name = null) | ||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | |||||
| if (tf.executing_eagerly()) | |||||
| { | { | ||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | ||||
| "Identity", name, | "Identity", name, | ||||
| @@ -140,7 +140,7 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null) | public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null) | ||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | |||||
| if (tf.executing_eagerly()) | |||||
| { | { | ||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | ||||
| "ReadVariableOp", name, | "ReadVariableOp", name, | ||||
| @@ -5,7 +5,7 @@ | |||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <TargetTensorFlow>2.2.0</TargetTensorFlow> | <TargetTensorFlow>2.2.0</TargetTensorFlow> | ||||
| <Version>0.20.0-preview3</Version> | |||||
| <Version>0.20.0-preview4</Version> | |||||
| <LangVersion>8.0</LangVersion> | <LangVersion>8.0</LangVersion> | ||||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
| <Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
| @@ -25,6 +25,7 @@ using System.Text; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -97,6 +98,8 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public SafeTensorHandleHandle EagerTensorHandle { get; set; } | public SafeTensorHandleHandle EagerTensorHandle { get; set; } | ||||
| public bool IsEagerTensor => this is EagerTensor; | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the shape of a tensor. | /// Returns the shape of a tensor. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -138,6 +141,11 @@ namespace Tensorflow | |||||
| public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); | public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); | ||||
| /// <summary> | |||||
| /// Keras History: (Layer, (node_index, tensor_index)) | |||||
| /// </summary> | |||||
| public List<Layer> KerasHistory = new List<Layer>(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Updates the shape of this tensor. | /// Updates the shape of this tensor. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest.Keras | |||||
| /// <summary> | /// <summary> | ||||
| /// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers | /// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers | ||||
| /// </summary> | /// </summary> | ||||
| [TestClass, Ignore] | |||||
| [TestClass] | |||||
| public class LayersTest : GraphModeTestBase | public class LayersTest : GraphModeTestBase | ||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -23,11 +23,15 @@ namespace TensorFlowNET.UnitTest.Keras | |||||
| model.add(tf.keras.Input(shape: 16)); | model.add(tf.keras.Input(shape: 16)); | ||||
| } | } | ||||
| [TestMethod] | |||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding | |||||
| /// </summary> | |||||
| [TestMethod, Ignore] | |||||
| public void Embedding() | public void Embedding() | ||||
| { | { | ||||
| var model = new Sequential(); | var model = new Sequential(); | ||||
| model.add(new Embedding(1000, 64, input_length: 10)); | |||||
| var layer = tf.keras.layers.Embedding(1000, 64, input_length: 10); | |||||
| model.add(layer); | |||||
| // the model will take as input an integer matrix of size (batch, | // the model will take as input an integer matrix of size (batch, | ||||
| // input_length). | // input_length). | ||||
| // the largest integer (i.e. word index) in the input should be no larger | // the largest integer (i.e. word index) in the input should be no larger | ||||
| @@ -35,15 +39,32 @@ namespace TensorFlowNET.UnitTest.Keras | |||||
| // now model.output_shape == (None, 10, 64), where None is the batch | // now model.output_shape == (None, 10, 64), where None is the batch | ||||
| // dimension. | // dimension. | ||||
| var input_array = np.random.randint(1000, size: (32, 10)); | var input_array = np.random.randint(1000, size: (32, 10)); | ||||
| model.compile("rmsprop", "mse"); | |||||
| // model.compile("rmsprop", "mse"); | |||||
| // output_array = model.predict(input_array) | |||||
| } | } | ||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense | |||||
| /// </summary> | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Dense() | public void Dense() | ||||
| { | { | ||||
| // Create a `Sequential` model and add a Dense layer as the first layer. | |||||
| var model = tf.keras.Sequential(); | var model = tf.keras.Sequential(); | ||||
| var dense_layer = tf.keras.layers.Dense(5, input_shape: 3); | |||||
| model.add(dense_layer); | |||||
| model.add(tf.keras.Input(shape: 16)); | |||||
| model.add(tf.keras.layers.Dense(32, activation: tf.keras.activations.Relu)); | |||||
| // Now the model will take as input arrays of shape (None, 16) | |||||
| // and output arrays of shape (None, 32). | |||||
| // Note that after the first layer, you don't need to specify | |||||
| // the size of the input anymore: | |||||
| model.add(tf.keras.layers.Dense(32)); | |||||
| Assert.AreEqual((-1, 32), model.output_shape); | |||||
| } | |||||
| [TestMethod] | |||||
| public void SimpleRNN() | |||||
| { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||