| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras; | using Tensorflow.Keras; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| @@ -12,10 +13,30 @@ namespace Tensorflow | |||||
| public static class layers | public static class layers | ||||
| { | { | ||||
| public static Embedding Embedding(int input_dim, int output_dim, | public static Embedding Embedding(int input_dim, int output_dim, | ||||
| string embeddings_initializer = "uniform", | |||||
| bool mask_zero = false) => new Embedding(input_dim, output_dim, | |||||
| embeddings_initializer, | |||||
| mask_zero); | |||||
| IInitializer embeddings_initializer = null, | |||||
| bool mask_zero = false) => new Embedding(input_dim, output_dim, | |||||
| embeddings_initializer, | |||||
| mask_zero); | |||||
| public static InputLayer Input(int[] batch_shape = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| string name = null, | |||||
| bool sparse = false, | |||||
| Tensor tensor = null) | |||||
| { | |||||
| var batch_size = batch_shape[0]; | |||||
| var shape = batch_shape.Skip(1).ToArray(); | |||||
| var input_layer = new InputLayer( | |||||
| input_shape: shape, | |||||
| batch_size: batch_size, | |||||
| name: name, | |||||
| dtype: dtype, | |||||
| sparse: sparse, | |||||
| input_tensor: tensor); | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -10,11 +10,16 @@ namespace Tensorflow.Keras.Engine | |||||
| protected bool _is_compiled; | protected bool _is_compiled; | ||||
| protected bool _expects_training_arg; | protected bool _expects_training_arg; | ||||
| protected bool _compute_output_and_mask_jointly; | protected bool _compute_output_and_mask_jointly; | ||||
| /// <summary> | |||||
| /// All layers in order of horizontal graph traversal. | |||||
| /// Entries are unique. Includes input and output layers. | |||||
| /// </summary> | |||||
| protected List<Layer> _layers; | |||||
| public Network(string name = null) | public Network(string name = null) | ||||
| : base(name: name) | : base(name: name) | ||||
| { | { | ||||
| _init_subclassed_network(name); | |||||
| } | } | ||||
| protected virtual void _init_subclassed_network(string name = null) | protected virtual void _init_subclassed_network(string name = null) | ||||
| @@ -30,6 +35,7 @@ namespace Tensorflow.Keras.Engine | |||||
| _expects_training_arg = false; | _expects_training_arg = false; | ||||
| _compute_output_and_mask_jointly = false; | _compute_output_and_mask_jointly = false; | ||||
| supports_masking = false; | supports_masking = false; | ||||
| _layers = new List<Layer>(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -23,6 +23,18 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| built = false; | built = false; | ||||
| var set_inputs = false; | var set_inputs = false; | ||||
| if(_layers.Count == 0) | |||||
| { | |||||
| var (batch_shape, dtype) = (layer._batch_input_shape, layer._dtype); | |||||
| if(batch_shape != null) | |||||
| { | |||||
| // Instantiate an input layer. | |||||
| var x = keras.layers.Input( | |||||
| batch_shape: batch_shape, | |||||
| dtype: dtype, | |||||
| name: layer._name + "_input"); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| public void __exit__() | public void __exit__() | ||||
| @@ -12,7 +12,9 @@ namespace Tensorflow.Keras.Layers | |||||
| public Embedding(int input_dim, int output_dim, | public Embedding(int input_dim, int output_dim, | ||||
| IInitializer embeddings_initializer = null, | IInitializer embeddings_initializer = null, | ||||
| bool mask_zero = false) | |||||
| bool mask_zero = false, | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| int[] input_shape = null) : base(dtype: dtype, input_shape: input_shape) | |||||
| { | { | ||||
| this.input_dim = input_dim; | this.input_dim = input_dim; | ||||
| this.output_dim = output_dim; | this.output_dim = output_dim; | ||||
| @@ -0,0 +1,45 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| /// <summary> | |||||
| /// Layer to be used as an entry point into a Network (a graph of layers). | |||||
| /// </summary> | |||||
| public class InputLayer : Layer | |||||
| { | |||||
| public bool sparse; | |||||
| public int? batch_size; | |||||
| public InputLayer(int[] input_shape = null, | |||||
| int? batch_size = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| string name = null, | |||||
| bool sparse = false, | |||||
| Tensor input_tensor = null) | |||||
| { | |||||
| built = true; | |||||
| this.sparse = sparse; | |||||
| this.batch_size = batch_size; | |||||
| this.supports_masking = true; | |||||
| if(input_tensor == null) | |||||
| { | |||||
| var batch_input_shape = new int[] { batch_size.HasValue ? batch_size.Value : -1, -1 }; | |||||
| if (sparse) | |||||
| { | |||||
| throw new NotImplementedException("InputLayer sparse is true"); | |||||
| } | |||||
| else | |||||
| { | |||||
| input_tensor = backend.placeholder( | |||||
| shape: batch_input_shape, | |||||
| dtype: dtype, | |||||
| name: name); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers | |||||
| /// </summary> | /// </summary> | ||||
| protected bool built; | protected bool built; | ||||
| protected bool trainable; | protected bool trainable; | ||||
| protected TF_DataType _dtype; | |||||
| public TF_DataType _dtype; | |||||
| /// <summary> | /// <summary> | ||||
| /// A stateful layer is a layer whose updates are run during inference too, | /// A stateful layer is a layer whose updates are run during inference too, | ||||
| /// for instance stateful RNNs. | /// for instance stateful RNNs. | ||||
| @@ -33,12 +33,16 @@ namespace Tensorflow.Keras.Layers | |||||
| protected InputSpec input_spec; | protected InputSpec input_spec; | ||||
| protected bool supports_masking; | protected bool supports_masking; | ||||
| protected List<RefVariable> _trainable_weights; | protected List<RefVariable> _trainable_weights; | ||||
| protected string _name; | |||||
| public string _name; | |||||
| protected string _base_name; | protected string _base_name; | ||||
| protected bool _compute_previous_mask; | protected bool _compute_previous_mask; | ||||
| protected List<Operation> _updates; | protected List<Operation> _updates; | ||||
| public int[] _batch_input_shape; | |||||
| public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| public Layer(bool trainable = true, | |||||
| string name = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| int[] input_shape = null) | |||||
| { | { | ||||
| this.trainable = trainable; | this.trainable = trainable; | ||||
| this._dtype = dtype; | this._dtype = dtype; | ||||
| @@ -49,6 +53,12 @@ namespace Tensorflow.Keras.Layers | |||||
| _trainable_weights = new List<RefVariable>(); | _trainable_weights = new List<RefVariable>(); | ||||
| _compute_previous_mask = false; | _compute_previous_mask = false; | ||||
| _updates = new List<Operation>(); | _updates = new List<Operation>(); | ||||
| // Manage input shape information if passed. | |||||
| _batch_input_shape = new int[] { -1, -1 }; | |||||
| _dtype = dtype; | |||||
| } | } | ||||
| public Tensor __call__(Tensor inputs, | public Tensor __call__(Tensor inputs, | ||||
| @@ -11,6 +11,22 @@ namespace Tensorflow.Keras | |||||
| } | } | ||||
| public static Tensor placeholder(int[] shape = null, | |||||
| int ndim = -1, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| bool sparse = false, | |||||
| string name = null) | |||||
| { | |||||
| if(sparse) | |||||
| { | |||||
| throw new NotImplementedException("placeholder sparse is true"); | |||||
| } | |||||
| else | |||||
| { | |||||
| return gen_array_ops.placeholder(dtype: dtype, shape: new TensorShape(shape), name: name); | |||||
| } | |||||
| } | |||||
| public static Graph get_graph() | public static Graph get_graph() | ||||
| { | { | ||||
| return ops.get_default_graph(); | return ops.get_default_graph(); | ||||