diff --git a/src/TensorFlowNET.Core/APIs/keras.layers.cs b/src/TensorFlowNET.Core/APIs/keras.layers.cs index aba24115..92900e76 100644 --- a/src/TensorFlowNET.Core/APIs/keras.layers.cs +++ b/src/TensorFlowNET.Core/APIs/keras.layers.cs @@ -38,13 +38,22 @@ namespace Tensorflow var batch_size = batch_shape[0]; var shape = batch_shape.Skip(1).ToArray(); - var input_layer = new InputLayer( - input_shape: shape, - batch_size: batch_size, - name: name, - dtype: dtype, - sparse: sparse, - input_tensor: tensor); + InputLayer input_layer = null; + if (batch_shape != null) + input_layer = new InputLayer( + batch_input_shape: batch_shape, + name: name, + dtype: dtype, + sparse: sparse, + input_tensor: tensor); + else + input_layer = new InputLayer( + input_shape: shape, + batch_size: batch_size, + name: name, + dtype: dtype, + sparse: sparse, + input_tensor: tensor); var outputs = input_layer.inbound_nodes[0].output_tensors; diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.cs index d56c49bc..d4cde39e 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.cs @@ -1,11 +1,33 @@ -namespace Tensorflow.Keras.Engine +using Tensorflow.Keras.Optimizers; + +namespace Tensorflow.Keras.Engine { public class Model : Network { + bool _cloning; + bool _is_compiled; + string loss; + IOptimizer optimizer; + public Model(string name = null) : base(name: name) { } + + public void compile(string optimizerName, string lossName) + { + switch (optimizerName) + { + case "rmsprop": + optimizer = new RMSprop(); + break; + } + + loss = lossName; + _is_compiled = true; + + // Prepare list of loss functions, same size of model outputs. + } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs index e18b401c..e9f85530 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs @@ -20,6 +20,9 @@ namespace Tensorflow.Keras.Engine { public class Sequential : Model, IObjectLife { + bool _is_graph_network; + Tensor[] outputs; + public Sequential(string name = null) : base(name: name) { @@ -42,21 +45,40 @@ namespace Tensorflow.Keras.Engine var set_inputs = false; if(_layers.Count == 0) { - var (batch_shape, dtype) = (layer._batch_input_shape, layer._dtype); - if(batch_shape != null) + if(layer is InputLayer) { - // Instantiate an input layer. - var x = keras.layers.Input( - batch_shape: batch_shape, - dtype: dtype, - name: layer.name + "_input"); - - // This will build the current layer - // and create the node connecting the current layer - // to the input layer we just created. - layer.__call__(x); - set_inputs = true; + } + else + { + var (batch_shape, dtype) = (layer._batch_input_shape, layer._dtype); + if (batch_shape != null) + { + // Instantiate an input layer. + var x = keras.layers.Input( + batch_shape: batch_shape, + dtype: dtype, + name: layer.name + "_input"); + + // This will build the current layer + // and create the node connecting the current layer + // to the input layer we just created. + layer.__call__(x); + set_inputs = true; + } + } + + if (set_inputs) + { + // If an input layer (placeholder) is available. + // outputs = layer._inbound_nodes; + } + + } + + if (set_inputs || _is_graph_network) + { + } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 929b3a3f..530ca76c 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers var param_shape = new int[] { input_shape.dims[axis[0]] }; if (scale) - gamma = add_weight("gamma", + gamma = (RefVariable)add_weight("gamma", param_shape, dtype: param_dtype, initializer: gamma_initializer, @@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Layers throw new NotImplementedException("add_weight gamma"); if (center) - beta = add_weight("beta", + beta = (RefVariable)add_weight("beta", param_shape, dtype: param_dtype, initializer: beta_initializer, @@ -117,7 +117,7 @@ namespace Tensorflow.Keras.Layers } - moving_mean = add_weight("moving_mean", + moving_mean = (RefVariable)add_weight("moving_mean", param_shape, dtype: param_dtype, initializer: moving_mean_initializer, @@ -125,7 +125,7 @@ namespace Tensorflow.Keras.Layers trainable: false, aggregation: VariableAggregation.Mean); - moving_variance = add_weight("moving_variance", + moving_variance = (RefVariable)add_weight("moving_variance", shape: param_shape, dtype: param_dtype, initializer: moving_variance_initializer, diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index 8319041f..dc40ae8c 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -75,13 +75,13 @@ namespace Tensorflow.Keras.Layers input_shape.dims[input_shape.ndim + channel_axis] : input_shape.dims[channel_axis]; var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters }; - kernel = add_weight(name: "kernel", + kernel = (RefVariable)add_weight(name: "kernel", shape: kernel_shape, initializer: kernel_initializer, trainable: true, dtype: _dtype); if (use_bias) - bias = add_weight(name: "bias", + bias = (RefVariable)add_weight(name: "bias", shape: new int[] { filters }, initializer: bias_initializer, trainable: true, diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index adfae5d1..2564da6d 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -55,14 +55,14 @@ namespace Tensorflow.Keras.Layers var axes = new Dictionary(); axes[-1] = last_dim; input_spec = new InputSpec(min_ndim: 2, axes: axes); - kernel = add_weight( + kernel = (RefVariable)add_weight( "kernel", shape: new int[] { last_dim, units }, initializer: kernel_initializer, dtype: _dtype, trainable: true); if (use_bias) - bias = add_weight( + bias = (RefVariable)add_weight( "bias", shape: new int[] { units }, initializer: bias_initializer, diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index 37f15baf..f10499c4 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -23,20 +23,23 @@ namespace Tensorflow.Keras.Layers private int input_dim; private int output_dim; private bool mask_zero; - public RefVariable embeddings; + public VariableV1 embeddings; public IInitializer embeddings_initializer; + int input_length; public Embedding(int input_dim, int output_dim, IInitializer embeddings_initializer = null, bool mask_zero = false, TF_DataType dtype = TF_DataType.TF_FLOAT, - int[] input_shape = null) : base(dtype: dtype, input_shape: input_shape) + int[] input_shape = null, + int input_length = -1) : base(dtype: dtype, input_shape: input_shape ?? new[] { input_length }) { this.input_dim = input_dim; this.output_dim = output_dim; this.embeddings_initializer = embeddings_initializer == null ? tf.uniform_initializer : embeddings_initializer; this.mask_zero = mask_zero; supports_masking = mask_zero; + this.input_length = input_length; } protected override void build(TensorShape input_shape) @@ -46,5 +49,15 @@ namespace Tensorflow.Keras.Layers name: "embeddings"); built = true; } + + protected override Tensor call(Tensor inputs, Tensor training = null) + { + var dtype = inputs.dtype; + if (dtype != tf.int32 && dtype != tf.int64) + inputs = math_ops.cast(inputs, tf.int32); + + var @out = embedding_ops.embedding_lookup(embeddings, inputs); + return @out; + } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs index ce029fa7..be5515ec 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System; +using System.Collections.Generic; +using System.Linq; namespace Tensorflow.Keras.Layers { @@ -28,21 +30,47 @@ namespace Tensorflow.Keras.Layers public bool is_placeholder; public InputLayer(int[] input_shape = null, + int[] batch_input_shape = null, int? batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool sparse = false, - Tensor input_tensor = null) + Tensor input_tensor = null) : base(dtype: dtype, name: name) { built = true; this.sparse = sparse; this.batch_size = batch_size; this.supports_masking = true; + if(batch_input_shape != null) + { + batch_size = batch_input_shape[0]; + input_shape = batch_input_shape.Skip(1).ToArray(); + } + + // moved to base class + if (string.IsNullOrEmpty(name)) + { + var prefix = "input"; + name = prefix + '_' + backend.get_uid(prefix); + } + if (input_tensor == null) { - var batch_input_shape = new int[] { batch_size.HasValue ? batch_size.Value : -1, -1 }; + if(input_shape != null) + { + var dims = new List { batch_size.HasValue ? batch_size.Value : -1 }; + dims.AddRange(input_shape); + batch_input_shape = dims.ToArray(); + } + else + { + batch_input_shape = null; + } + + var graph = backend.get_graph().as_default(); + // In graph mode, create a graph placeholder to call the layer on. if (sparse) { throw new NotImplementedException("InputLayer sparse is true"); @@ -59,6 +87,10 @@ namespace Tensorflow.Keras.Layers _batch_input_shape = batch_input_shape; } + // Create an input node to add to self.outbound_node + // and set output_tensors' _keras_history. + // input_tensor._keras_history = base_layer.KerasHistory(self, 0, 0) + // input_tensor._keras_mask = None new Node(this, inbound_layers: new Layer[0], node_indices: new int[0], diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 6681ec56..22cef8e1 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -51,7 +51,7 @@ namespace Tensorflow.Keras.Layers /// protected InputSpec input_spec; protected bool supports_masking; - protected List _trainable_weights; + protected List _trainable_weights; private string _name; public string name => _name; protected string _base_name; @@ -65,6 +65,8 @@ namespace Tensorflow.Keras.Layers private List _outbound_nodes; public List outbound_nodes => _outbound_nodes; + float _initial_weights; + public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, @@ -81,13 +83,18 @@ namespace Tensorflow.Keras.Layers this.supports_masking = false; _init_set_name(name); - _trainable_weights = new List(); + _trainable_weights = new List(); _compute_previous_mask = false; _updates = new List(); // Manage input shape information if passed. - - _batch_input_shape = new int[] { -1, -1 }; + if(input_shape != null) + { + var shapes = new List { -1 }; + shapes.AddRange(input_shape); + _batch_input_shape = shapes.ToArray(); + } + _dtype = dtype; @@ -186,12 +193,12 @@ namespace Tensorflow.Keras.Layers built = true; } - protected virtual RefVariable add_weight(string name, + protected virtual VariableV1 add_weight(string name, int[] shape, TF_DataType dtype = TF_DataType.DtInvalid, IInitializer initializer = null, bool? trainable = null, - Func getter = null) + Func getter = null) { if (dtype == TF_DataType.DtInvalid) dtype = TF_DataType.TF_FLOAT; diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/IOptimizer.cs b/src/TensorFlowNET.Core/Keras/Optimizers/IOptimizer.cs new file mode 100644 index 00000000..0c1d411e --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Optimizers/IOptimizer.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Optimizers +{ + public interface IOptimizer + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs new file mode 100644 index 00000000..2f22a721 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Optimizers +{ + /// + /// Updated base class for optimizers. + /// + public class OptimizerV2 : Trackable, IOptimizer + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs b/src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs new file mode 100644 index 00000000..51b65b57 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Optimizers +{ + /// + /// Optimizer that implements the RMSprop algorithm. + /// + public class RMSprop : OptimizerV2 + { + + } +} diff --git a/src/TensorFlowNET.Core/Keras/backend.cs b/src/TensorFlowNET.Core/Keras/backend.cs index 46769bd8..73d7d335 100644 --- a/src/TensorFlowNET.Core/Keras/backend.cs +++ b/src/TensorFlowNET.Core/Keras/backend.cs @@ -42,12 +42,12 @@ namespace Tensorflow.Keras /// Allows to give unique autogenerated names to layers, in a graph-specific way. /// public static Dictionary> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>(); - public static Dictionary _GRAPH_VARIABLES = new Dictionary(); + public static Dictionary _GRAPH_VARIABLES = new Dictionary(); public static Dictionary _GRAPH_TF_OPTIMIZERS = new Dictionary(); public static _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph(); - public static void track_variable(RefVariable v) + public static void track_variable(VariableV1 v) { var graph = v.graph; _GRAPH_VARIABLES[graph.graph_key] = v; diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 444c2dd4..138f0fc7 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -42,7 +42,7 @@ namespace Tensorflow.Layers this._reuse = _reuse; // Avoid an incorrect lint error - _trainable_weights = new List(); + _trainable_weights = new List(); this.built = false; _keras_style = false; } @@ -109,7 +109,7 @@ namespace Tensorflow.Layers /// /// /// - protected virtual RefVariable add_weight(string name, + protected virtual VariableV1 add_weight(string name, int[] shape, TF_DataType dtype = TF_DataType.DtInvalid, IInitializer initializer = null, diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 12094e41..b7ef6440 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -600,7 +600,7 @@ namespace Tensorflow return gen_array_ops.concat_v2(values, axis, name: name); } - public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0) + public static Tensor gather(T1 @params, T2 indices, string name = null, int axis = 0) => gen_array_ops.gather_v2(@params, indices, axis, name: name); public static Tensor transpose(T1 a, T2 perm, string name = "transpose", bool conjugate = false) diff --git a/src/TensorFlowNET.Core/Operations/embedding_ops.cs b/src/TensorFlowNET.Core/Operations/embedding_ops.cs index 3c02e825..1b23fab3 100644 --- a/src/TensorFlowNET.Core/Operations/embedding_ops.cs +++ b/src/TensorFlowNET.Core/Operations/embedding_ops.cs @@ -52,6 +52,38 @@ namespace Tensorflow }); } + /// + /// Helper function for embedding_lookup and _compute_sampled_logits. + /// + /// + /// + /// + /// + /// + /// + public static Tensor _embedding_lookup_and_transform(VariableV1 @params, + Tensor ids, + string partition_strategy = "mod", + string name = null, + string max_norm = null) + { + return tf_with(ops.name_scope(name, "embedding_lookup", new { @params, ids }), scope => + { + name = scope; + int np = 1; + ids = ops.convert_to_tensor(ids, name: "ids"); + if (np == 1) + { + var gather = array_ops.gather(@params, ids, name: name); + var result = _clip(gather, ids, max_norm); + + return array_ops.identity(result); + } + + throw new NotImplementedException("_embedding_lookup_and_transform"); + }); + } + public static Tensor _embedding_lookup_and_transform(Tensor[] @params, Tensor ids, string partition_strategy = "mod", @@ -98,5 +130,18 @@ namespace Tensorflow name: name, max_norm: max_norm); } + + public static Tensor embedding_lookup(VariableV1 @params, Tensor ids, + string partition_strategy = "mod", + string name = null, + bool validate_indices = true, + string max_norm = null) + { + return _embedding_lookup_and_transform(@params: @params, + ids: ids, + partition_strategy: partition_strategy, + name: name, + max_norm: max_norm); + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 847ace24..01231035 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -106,7 +106,7 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor gather_v2(Tensor @params, Tensor indices, int axis, string name = null) + public static Tensor gather_v2(T1 @params, T2 indices, int axis, string name = null) { var _op = _op_def_lib._apply_op_helper("GatherV2", name: name, new { @params, indices, axis }); diff --git a/src/TensorFlowNET.Core/Train/Trackable.cs b/src/TensorFlowNET.Core/Train/Trackable.cs index a718c869..975546f7 100644 --- a/src/TensorFlowNET.Core/Train/Trackable.cs +++ b/src/TensorFlowNET.Core/Train/Trackable.cs @@ -26,11 +26,11 @@ namespace Tensorflow.Train /// Restore-on-create for a variable be saved with this `Checkpointable`. /// /// - protected virtual RefVariable _add_variable_with_custom_getter(string name, + protected virtual VariableV1 _add_variable_with_custom_getter(string name, int[] shape, TF_DataType dtype = TF_DataType.TF_FLOAT, IInitializer initializer = null, - Func getter = null, + Func getter = null, bool overwrite = false, bool trainable = false) { @@ -59,7 +59,7 @@ namespace Tensorflow.Train // TODO } - protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false) + protected VariableV1 _track_checkpointable(VariableV1 checkpointable, string name, bool overwrite = false) { return checkpointable; } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 97e1d0f4..35e0da87 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -28,14 +28,14 @@ namespace Tensorflow public Tensor _initial_value; public string _graph_key; public bool _trainable; - public Tensor _variable; + public Tensor _snapshot; public bool _save_slice_info; private Operation _initializer_op; public override Operation initializer => _initializer_op; public override Operation op => _variable.op; - public Graph graph => _variable.graph; + public TF_DataType dtype => _variable.dtype; public TensorShape shape => tensor_util.to_shape(_variable.shape); diff --git a/src/TensorFlowNET.Core/Variables/VariableV1.cs b/src/TensorFlowNET.Core/Variables/VariableV1.cs index eb3349fd..48e1952c 100644 --- a/src/TensorFlowNET.Core/Variables/VariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/VariableV1.cs @@ -34,7 +34,8 @@ namespace Tensorflow public virtual Tensor graph_element { get; } public virtual Operation op { get; } public virtual Operation initializer { get; } - + public Tensor _variable; + public Graph graph => _variable.graph; public VariableV1(object initial_value = null, bool trainable = true, List collections = null, diff --git a/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs b/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs new file mode 100644 index 00000000..896ad430 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs @@ -0,0 +1,29 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using NumSharp; + +namespace TensorFlowNET.UnitTest.Keras +{ + [TestClass] + public class EmbeddingTest + { + [TestMethod] + public void Embedding() + { + var model = new Sequential(); + model.add(new Embedding(1000, 64, input_length: 10)); + // the model will take as input an integer matrix of size (batch, + // input_length). + // the largest integer (i.e. word index) in the input should be no larger + // than 999 (vocabulary size). + // now model.output_shape == (None, 10, 64), where None is the batch + // dimension. + var input_array = np.random.randint(1000, size: (32, 10)); + model.compile("rmsprop", "mse"); + } + } +}