| @@ -38,13 +38,22 @@ namespace Tensorflow | |||
| var batch_size = batch_shape[0]; | |||
| var shape = batch_shape.Skip(1).ToArray(); | |||
| var input_layer = new InputLayer( | |||
| input_shape: shape, | |||
| batch_size: batch_size, | |||
| name: name, | |||
| dtype: dtype, | |||
| sparse: sparse, | |||
| input_tensor: tensor); | |||
| InputLayer input_layer = null; | |||
| if (batch_shape != null) | |||
| input_layer = new InputLayer( | |||
| batch_input_shape: batch_shape, | |||
| name: name, | |||
| dtype: dtype, | |||
| sparse: sparse, | |||
| input_tensor: tensor); | |||
| else | |||
| input_layer = new InputLayer( | |||
| input_shape: shape, | |||
| batch_size: batch_size, | |||
| name: name, | |||
| dtype: dtype, | |||
| sparse: sparse, | |||
| input_tensor: tensor); | |||
| var outputs = input_layer.inbound_nodes[0].output_tensors; | |||
| @@ -1,11 +1,33 @@ | |||
| namespace Tensorflow.Keras.Engine | |||
| using Tensorflow.Keras.Optimizers; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| public class Model : Network | |||
| { | |||
| bool _cloning; | |||
| bool _is_compiled; | |||
| string loss; | |||
| IOptimizer optimizer; | |||
| public Model(string name = null) | |||
| : base(name: name) | |||
| { | |||
| } | |||
| public void compile(string optimizerName, string lossName) | |||
| { | |||
| switch (optimizerName) | |||
| { | |||
| case "rmsprop": | |||
| optimizer = new RMSprop(); | |||
| break; | |||
| } | |||
| loss = lossName; | |||
| _is_compiled = true; | |||
| // Prepare list of loss functions, same size of model outputs. | |||
| } | |||
| } | |||
| } | |||
| @@ -20,6 +20,9 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| public class Sequential : Model, IObjectLife | |||
| { | |||
| bool _is_graph_network; | |||
| Tensor[] outputs; | |||
| public Sequential(string name = null) | |||
| : base(name: name) | |||
| { | |||
| @@ -42,21 +45,40 @@ namespace Tensorflow.Keras.Engine | |||
| var set_inputs = false; | |||
| if(_layers.Count == 0) | |||
| { | |||
| var (batch_shape, dtype) = (layer._batch_input_shape, layer._dtype); | |||
| if(batch_shape != null) | |||
| if(layer is InputLayer) | |||
| { | |||
| // Instantiate an input layer. | |||
| var x = keras.layers.Input( | |||
| batch_shape: batch_shape, | |||
| dtype: dtype, | |||
| name: layer.name + "_input"); | |||
| // This will build the current layer | |||
| // and create the node connecting the current layer | |||
| // to the input layer we just created. | |||
| layer.__call__(x); | |||
| set_inputs = true; | |||
| } | |||
| else | |||
| { | |||
| var (batch_shape, dtype) = (layer._batch_input_shape, layer._dtype); | |||
| if (batch_shape != null) | |||
| { | |||
| // Instantiate an input layer. | |||
| var x = keras.layers.Input( | |||
| batch_shape: batch_shape, | |||
| dtype: dtype, | |||
| name: layer.name + "_input"); | |||
| // This will build the current layer | |||
| // and create the node connecting the current layer | |||
| // to the input layer we just created. | |||
| layer.__call__(x); | |||
| set_inputs = true; | |||
| } | |||
| } | |||
| if (set_inputs) | |||
| { | |||
| // If an input layer (placeholder) is available. | |||
| // outputs = layer._inbound_nodes; | |||
| } | |||
| } | |||
| if (set_inputs || _is_graph_network) | |||
| { | |||
| } | |||
| } | |||
| @@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Layers | |||
| var param_shape = new int[] { input_shape.dims[axis[0]] }; | |||
| if (scale) | |||
| gamma = add_weight("gamma", | |||
| gamma = (RefVariable)add_weight("gamma", | |||
| param_shape, | |||
| dtype: param_dtype, | |||
| initializer: gamma_initializer, | |||
| @@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Layers | |||
| throw new NotImplementedException("add_weight gamma"); | |||
| if (center) | |||
| beta = add_weight("beta", | |||
| beta = (RefVariable)add_weight("beta", | |||
| param_shape, | |||
| dtype: param_dtype, | |||
| initializer: beta_initializer, | |||
| @@ -117,7 +117,7 @@ namespace Tensorflow.Keras.Layers | |||
| } | |||
| moving_mean = add_weight("moving_mean", | |||
| moving_mean = (RefVariable)add_weight("moving_mean", | |||
| param_shape, | |||
| dtype: param_dtype, | |||
| initializer: moving_mean_initializer, | |||
| @@ -125,7 +125,7 @@ namespace Tensorflow.Keras.Layers | |||
| trainable: false, | |||
| aggregation: VariableAggregation.Mean); | |||
| moving_variance = add_weight("moving_variance", | |||
| moving_variance = (RefVariable)add_weight("moving_variance", | |||
| shape: param_shape, | |||
| dtype: param_dtype, | |||
| initializer: moving_variance_initializer, | |||
| @@ -75,13 +75,13 @@ namespace Tensorflow.Keras.Layers | |||
| input_shape.dims[input_shape.ndim + channel_axis] : | |||
| input_shape.dims[channel_axis]; | |||
| var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters }; | |||
| kernel = add_weight(name: "kernel", | |||
| kernel = (RefVariable)add_weight(name: "kernel", | |||
| shape: kernel_shape, | |||
| initializer: kernel_initializer, | |||
| trainable: true, | |||
| dtype: _dtype); | |||
| if (use_bias) | |||
| bias = add_weight(name: "bias", | |||
| bias = (RefVariable)add_weight(name: "bias", | |||
| shape: new int[] { filters }, | |||
| initializer: bias_initializer, | |||
| trainable: true, | |||
| @@ -55,14 +55,14 @@ namespace Tensorflow.Keras.Layers | |||
| var axes = new Dictionary<int, int>(); | |||
| axes[-1] = last_dim; | |||
| input_spec = new InputSpec(min_ndim: 2, axes: axes); | |||
| kernel = add_weight( | |||
| kernel = (RefVariable)add_weight( | |||
| "kernel", | |||
| shape: new int[] { last_dim, units }, | |||
| initializer: kernel_initializer, | |||
| dtype: _dtype, | |||
| trainable: true); | |||
| if (use_bias) | |||
| bias = add_weight( | |||
| bias = (RefVariable)add_weight( | |||
| "bias", | |||
| shape: new int[] { units }, | |||
| initializer: bias_initializer, | |||
| @@ -23,20 +23,23 @@ namespace Tensorflow.Keras.Layers | |||
| private int input_dim; | |||
| private int output_dim; | |||
| private bool mask_zero; | |||
| public RefVariable embeddings; | |||
| public VariableV1 embeddings; | |||
| public IInitializer embeddings_initializer; | |||
| int input_length; | |||
| public Embedding(int input_dim, int output_dim, | |||
| IInitializer embeddings_initializer = null, | |||
| bool mask_zero = false, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| int[] input_shape = null) : base(dtype: dtype, input_shape: input_shape) | |||
| int[] input_shape = null, | |||
| int input_length = -1) : base(dtype: dtype, input_shape: input_shape ?? new[] { input_length }) | |||
| { | |||
| this.input_dim = input_dim; | |||
| this.output_dim = output_dim; | |||
| this.embeddings_initializer = embeddings_initializer == null ? tf.uniform_initializer : embeddings_initializer; | |||
| this.mask_zero = mask_zero; | |||
| supports_masking = mask_zero; | |||
| this.input_length = input_length; | |||
| } | |||
| protected override void build(TensorShape input_shape) | |||
| @@ -46,5 +49,15 @@ namespace Tensorflow.Keras.Layers | |||
| name: "embeddings"); | |||
| built = true; | |||
| } | |||
| protected override Tensor call(Tensor inputs, Tensor training = null) | |||
| { | |||
| var dtype = inputs.dtype; | |||
| if (dtype != tf.int32 && dtype != tf.int64) | |||
| inputs = math_ops.cast(inputs, tf.int32); | |||
| var @out = embedding_ops.embedding_lookup(embeddings, inputs); | |||
| return @out; | |||
| } | |||
| } | |||
| } | |||
| @@ -15,6 +15,8 @@ | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| @@ -28,21 +30,47 @@ namespace Tensorflow.Keras.Layers | |||
| public bool is_placeholder; | |||
| public InputLayer(int[] input_shape = null, | |||
| int[] batch_input_shape = null, | |||
| int? batch_size = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| string name = null, | |||
| bool sparse = false, | |||
| Tensor input_tensor = null) | |||
| Tensor input_tensor = null) : base(dtype: dtype, name: name) | |||
| { | |||
| built = true; | |||
| this.sparse = sparse; | |||
| this.batch_size = batch_size; | |||
| this.supports_masking = true; | |||
| if(batch_input_shape != null) | |||
| { | |||
| batch_size = batch_input_shape[0]; | |||
| input_shape = batch_input_shape.Skip(1).ToArray(); | |||
| } | |||
| // moved to base class | |||
| if (string.IsNullOrEmpty(name)) | |||
| { | |||
| var prefix = "input"; | |||
| name = prefix + '_' + backend.get_uid(prefix); | |||
| } | |||
| if (input_tensor == null) | |||
| { | |||
| var batch_input_shape = new int[] { batch_size.HasValue ? batch_size.Value : -1, -1 }; | |||
| if(input_shape != null) | |||
| { | |||
| var dims = new List<int> { batch_size.HasValue ? batch_size.Value : -1 }; | |||
| dims.AddRange(input_shape); | |||
| batch_input_shape = dims.ToArray(); | |||
| } | |||
| else | |||
| { | |||
| batch_input_shape = null; | |||
| } | |||
| var graph = backend.get_graph().as_default(); | |||
| // In graph mode, create a graph placeholder to call the layer on. | |||
| if (sparse) | |||
| { | |||
| throw new NotImplementedException("InputLayer sparse is true"); | |||
| @@ -59,6 +87,10 @@ namespace Tensorflow.Keras.Layers | |||
| _batch_input_shape = batch_input_shape; | |||
| } | |||
| // Create an input node to add to self.outbound_node | |||
| // and set output_tensors' _keras_history. | |||
| // input_tensor._keras_history = base_layer.KerasHistory(self, 0, 0) | |||
| // input_tensor._keras_mask = None | |||
| new Node(this, | |||
| inbound_layers: new Layer[0], | |||
| node_indices: new int[0], | |||
| @@ -51,7 +51,7 @@ namespace Tensorflow.Keras.Layers | |||
| /// </summary> | |||
| protected InputSpec input_spec; | |||
| protected bool supports_masking; | |||
| protected List<RefVariable> _trainable_weights; | |||
| protected List<VariableV1> _trainable_weights; | |||
| private string _name; | |||
| public string name => _name; | |||
| protected string _base_name; | |||
| @@ -65,6 +65,8 @@ namespace Tensorflow.Keras.Layers | |||
| private List<Node> _outbound_nodes; | |||
| public List<Node> outbound_nodes => _outbound_nodes; | |||
| float _initial_weights; | |||
| public Layer(bool trainable = true, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| @@ -81,13 +83,18 @@ namespace Tensorflow.Keras.Layers | |||
| this.supports_masking = false; | |||
| _init_set_name(name); | |||
| _trainable_weights = new List<RefVariable>(); | |||
| _trainable_weights = new List<VariableV1>(); | |||
| _compute_previous_mask = false; | |||
| _updates = new List<Operation>(); | |||
| // Manage input shape information if passed. | |||
| _batch_input_shape = new int[] { -1, -1 }; | |||
| if(input_shape != null) | |||
| { | |||
| var shapes = new List<int> { -1 }; | |||
| shapes.AddRange(input_shape); | |||
| _batch_input_shape = shapes.ToArray(); | |||
| } | |||
| _dtype = dtype; | |||
| @@ -186,12 +193,12 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected virtual RefVariable add_weight(string name, | |||
| protected virtual VariableV1 add_weight(string name, | |||
| int[] shape, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| IInitializer initializer = null, | |||
| bool? trainable = null, | |||
| Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null) | |||
| Func<string, int[], TF_DataType, IInitializer, bool, VariableV1> getter = null) | |||
| { | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = TF_DataType.TF_FLOAT; | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.Optimizers | |||
| { | |||
| public interface IOptimizer | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Keras.Optimizers | |||
| { | |||
| /// <summary> | |||
| /// Updated base class for optimizers. | |||
| /// </summary> | |||
| public class OptimizerV2 : Trackable, IOptimizer | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.Optimizers | |||
| { | |||
| /// <summary> | |||
| /// Optimizer that implements the RMSprop algorithm. | |||
| /// </summary> | |||
| public class RMSprop : OptimizerV2 | |||
| { | |||
| } | |||
| } | |||
| @@ -42,12 +42,12 @@ namespace Tensorflow.Keras | |||
| /// Allows to give unique autogenerated names to layers, in a graph-specific way. | |||
| /// </summary> | |||
| public static Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>(); | |||
| public static Dictionary<string, RefVariable> _GRAPH_VARIABLES = new Dictionary<string, RefVariable>(); | |||
| public static Dictionary<string, VariableV1> _GRAPH_VARIABLES = new Dictionary<string, VariableV1>(); | |||
| public static Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>(); | |||
| public static _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph(); | |||
| public static void track_variable(RefVariable v) | |||
| public static void track_variable(VariableV1 v) | |||
| { | |||
| var graph = v.graph; | |||
| _GRAPH_VARIABLES[graph.graph_key] = v; | |||
| @@ -42,7 +42,7 @@ namespace Tensorflow.Layers | |||
| this._reuse = _reuse; | |||
| // Avoid an incorrect lint error | |||
| _trainable_weights = new List<RefVariable>(); | |||
| _trainable_weights = new List<VariableV1>(); | |||
| this.built = false; | |||
| _keras_style = false; | |||
| } | |||
| @@ -109,7 +109,7 @@ namespace Tensorflow.Layers | |||
| /// <param name="synchronization"></param> | |||
| /// <param name="aggregation"></param> | |||
| /// <returns></returns> | |||
| protected virtual RefVariable add_weight(string name, | |||
| protected virtual VariableV1 add_weight(string name, | |||
| int[] shape, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| IInitializer initializer = null, | |||
| @@ -600,7 +600,7 @@ namespace Tensorflow | |||
| return gen_array_ops.concat_v2(values, axis, name: name); | |||
| } | |||
| public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0) | |||
| public static Tensor gather<T1, T2>(T1 @params, T2 indices, string name = null, int axis = 0) | |||
| => gen_array_ops.gather_v2(@params, indices, axis, name: name); | |||
| public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false) | |||
| @@ -52,6 +52,38 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
| /// <summary> | |||
| /// Helper function for embedding_lookup and _compute_sampled_logits. | |||
| /// </summary> | |||
| /// <param name="params"></param> | |||
| /// <param name="ids"></param> | |||
| /// <param name="partition_strategy"></param> | |||
| /// <param name="name"></param> | |||
| /// <param name="max_norm"></param> | |||
| /// <returns></returns> | |||
| public static Tensor _embedding_lookup_and_transform(VariableV1 @params, | |||
| Tensor ids, | |||
| string partition_strategy = "mod", | |||
| string name = null, | |||
| string max_norm = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "embedding_lookup", new { @params, ids }), scope => | |||
| { | |||
| name = scope; | |||
| int np = 1; | |||
| ids = ops.convert_to_tensor(ids, name: "ids"); | |||
| if (np == 1) | |||
| { | |||
| var gather = array_ops.gather(@params, ids, name: name); | |||
| var result = _clip(gather, ids, max_norm); | |||
| return array_ops.identity(result); | |||
| } | |||
| throw new NotImplementedException("_embedding_lookup_and_transform"); | |||
| }); | |||
| } | |||
| public static Tensor _embedding_lookup_and_transform(Tensor[] @params, | |||
| Tensor ids, | |||
| string partition_strategy = "mod", | |||
| @@ -98,5 +130,18 @@ namespace Tensorflow | |||
| name: name, | |||
| max_norm: max_norm); | |||
| } | |||
| public static Tensor embedding_lookup(VariableV1 @params, Tensor ids, | |||
| string partition_strategy = "mod", | |||
| string name = null, | |||
| bool validate_indices = true, | |||
| string max_norm = null) | |||
| { | |||
| return _embedding_lookup_and_transform(@params: @params, | |||
| ids: ids, | |||
| partition_strategy: partition_strategy, | |||
| name: name, | |||
| max_norm: max_norm); | |||
| } | |||
| } | |||
| } | |||
| @@ -106,7 +106,7 @@ namespace Tensorflow | |||
| return _op.outputs[0]; | |||
| } | |||
| public static Tensor gather_v2(Tensor @params, Tensor indices, int axis, string name = null) | |||
| public static Tensor gather_v2<T1, T2>(T1 @params, T2 indices, int axis, string name = null) | |||
| { | |||
| var _op = _op_def_lib._apply_op_helper("GatherV2", name: name, new { @params, indices, axis }); | |||
| @@ -26,11 +26,11 @@ namespace Tensorflow.Train | |||
| /// Restore-on-create for a variable be saved with this `Checkpointable`. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| protected virtual RefVariable _add_variable_with_custom_getter(string name, | |||
| protected virtual VariableV1 _add_variable_with_custom_getter(string name, | |||
| int[] shape, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| IInitializer initializer = null, | |||
| Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null, | |||
| Func<string, int[], TF_DataType, IInitializer, bool, VariableV1> getter = null, | |||
| bool overwrite = false, | |||
| bool trainable = false) | |||
| { | |||
| @@ -59,7 +59,7 @@ namespace Tensorflow.Train | |||
| // TODO | |||
| } | |||
| protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false) | |||
| protected VariableV1 _track_checkpointable(VariableV1 checkpointable, string name, bool overwrite = false) | |||
| { | |||
| return checkpointable; | |||
| } | |||
| @@ -28,14 +28,14 @@ namespace Tensorflow | |||
| public Tensor _initial_value; | |||
| public string _graph_key; | |||
| public bool _trainable; | |||
| public Tensor _variable; | |||
| public Tensor _snapshot; | |||
| public bool _save_slice_info; | |||
| private Operation _initializer_op; | |||
| public override Operation initializer => _initializer_op; | |||
| public override Operation op => _variable.op; | |||
| public Graph graph => _variable.graph; | |||
| public TF_DataType dtype => _variable.dtype; | |||
| public TensorShape shape => tensor_util.to_shape(_variable.shape); | |||
| @@ -34,7 +34,8 @@ namespace Tensorflow | |||
| public virtual Tensor graph_element { get; } | |||
| public virtual Operation op { get; } | |||
| public virtual Operation initializer { get; } | |||
| public Tensor _variable; | |||
| public Graph graph => _variable.graph; | |||
| public VariableV1(object initial_value = null, | |||
| bool trainable = true, | |||
| List<string> collections = null, | |||
| @@ -0,0 +1,29 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| using NumSharp; | |||
| namespace TensorFlowNET.UnitTest.Keras | |||
| { | |||
| [TestClass] | |||
| public class EmbeddingTest | |||
| { | |||
| [TestMethod] | |||
| public void Embedding() | |||
| { | |||
| var model = new Sequential(); | |||
| model.add(new Embedding(1000, 64, input_length: 10)); | |||
| // the model will take as input an integer matrix of size (batch, | |||
| // input_length). | |||
| // the largest integer (i.e. word index) in the input should be no larger | |||
| // than 999 (vocabulary size). | |||
| // now model.output_shape == (None, 10, 64), where None is the batch | |||
| // dimension. | |||
| var input_array = np.random.randint(1000, size: (32, 10)); | |||
| model.compile("rmsprop", "mse"); | |||
| } | |||
| } | |||
| } | |||