From de2383162bd50c274381dba77446bea78fc736a2 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 5 Aug 2020 06:25:38 -0500 Subject: [PATCH] tf.keras.layers #355 --- .../Gradients/GradientTape.cs | 6 +- .../Gradients/Tape.ComputeGradient.cs | 20 +++- .../Activations.Linear.cs} | 6 +- .../Keras/Activations/Activations.Relu.cs | 26 +++++ .../Keras/Activations/Activations.cs | 9 ++ .../Keras/ArgsDefinition/NodeArgs.cs | 4 +- .../Keras/Engine/Layer.Layers.cs | 29 +++++ src/TensorFlowNET.Core/Keras/Engine/Layer.cs | 107 +++++++----------- src/TensorFlowNET.Core/Keras/Engine/Node.cs | 6 +- .../Keras/Engine/Sequential.cs | 13 ++- src/TensorFlowNET.Core/Keras/KerasApi.cs | 30 ++++- .../Keras/Layers/BatchNormalization.cs | 6 +- src/TensorFlowNET.Core/Keras/Layers/Conv.cs | 6 +- src/TensorFlowNET.Core/Keras/Layers/Dense.cs | 8 +- .../Keras/Layers/Embedding.cs | 10 +- .../Keras/Layers/InputLayer.cs | 4 +- .../Keras/Layers/Pooling2D.cs | 6 +- src/TensorFlowNET.Core/Keras/Models.cs | 13 +++ .../Keras/Optimizers/OptimizerV2.cs | 10 +- .../Keras/Optimizers/SGD.cs | 2 +- src/TensorFlowNET.Core/Layers/Layer.cs | 5 +- .../Operations/NnOps/BasicLSTMCell.cs | 8 +- .../Operations/NnOps/BasicRNNCell.cs | 6 +- .../Operations/NnOps/gen_nn_ops.cs | 34 +++++- .../Operations/array_ops.cs | 41 ++++++- .../Operations/gen_array_ops.cs | 32 ------ .../Variables/BaseResourceVariable.cs | 3 +- .../Variables/IVariableV1.cs | 1 + .../Variables/RefVariable.cs | 1 + src/TensorFlowNET.Core/tensorflow.cs | 1 - .../Keras/LayersTest.cs | 7 ++ 31 files changed, 297 insertions(+), 163 deletions(-) rename src/TensorFlowNET.Core/Keras/{Activations.cs => Activations/Activations.Linear.cs} (67%) create mode 100644 src/TensorFlowNET.Core/Keras/Activations/Activations.Relu.cs create mode 100644 src/TensorFlowNET.Core/Keras/Activations/Activations.cs create mode 100644 src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs create mode 100644 src/TensorFlowNET.Core/Keras/Models.cs diff --git a/src/TensorFlowNET.Core/Gradients/GradientTape.cs b/src/TensorFlowNET.Core/Gradients/GradientTape.cs index a8c0d8fd..dccb9574 100644 --- a/src/TensorFlowNET.Core/Gradients/GradientTape.cs +++ b/src/TensorFlowNET.Core/Gradients/GradientTape.cs @@ -107,19 +107,19 @@ namespace Tensorflow.Gradients public Tensor gradient(Tensor target, ResourceVariable source) { - var results = gradient(target, new[] { source }); + var results = gradient(target, new List { source }); return results[0]; } public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources) { - var results = gradient(target, new[] { sources.Item1, sources.Item2 }); + var results = gradient(target, new List { sources.Item1, sources.Item2 }); return (results[0], results[1]); } - public Tensor[] gradient(Tensor target, IEnumerable sources) + public Tensor[] gradient(Tensor target, List sources) { if (_recording) { diff --git a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs index 94e0d3ee..770b75ca 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs @@ -54,7 +54,16 @@ namespace Tensorflow.Gradients var id = trace.output_tensor_info[i].GetID(); if (!gradients.find(id, out var grad_it)) { - throw new NotImplementedException("FunctionsAcceptingNoneForIndicesMap"); + if (FunctionsAcceptingNoneForIndicesMap().find(trace.op_type, out var func_name_it) && + func_name_it.find(i)) + { + out_gradients.Add(null); + } + else + { + out_gradients.Add(null); + zero_indices.Add(i); + } } else { @@ -184,6 +193,15 @@ namespace Tensorflow.Gradients return result.ToArray(); } + UnorderedMap> FunctionsAcceptingNoneForIndicesMap() + { + var m = new UnorderedMap>(); + m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet(new[] { 1 })); + m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet(new[] { 1 })); + m.Add("FusedBatchNorm", new UnorderedSet(new[] { 1, 2, 3, 4 })); + return m; + } + UnorderedMapEnumerable> InitialGradients(long[] target_tensor_ids, UnorderedMap sources_that_are_targets, Tensor[] output_gradients, diff --git a/src/TensorFlowNET.Core/Keras/Activations.cs b/src/TensorFlowNET.Core/Keras/Activations/Activations.Linear.cs similarity index 67% rename from src/TensorFlowNET.Core/Keras/Activations.cs rename to src/TensorFlowNET.Core/Keras/Activations/Activations.Linear.cs index 77a83fbc..fd1ce7ab 100644 --- a/src/TensorFlowNET.Core/Keras/Activations.cs +++ b/src/TensorFlowNET.Core/Keras/Activations/Activations.Linear.cs @@ -5,13 +5,11 @@ using static Tensorflow.Binding; namespace Tensorflow.Keras { - public delegate Tensor Activation(Tensor x); - - public class Activations + public partial class Activations { /// /// Linear activation function (pass-through). /// - public Activation Linear = x => x; + public Activation Linear = (features, name) => features; } } diff --git a/src/TensorFlowNET.Core/Keras/Activations/Activations.Relu.cs b/src/TensorFlowNET.Core/Keras/Activations/Activations.Relu.cs new file mode 100644 index 00000000..3958f702 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Activations/Activations.Relu.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras +{ + public partial class Activations + { + public Activation Relu = (features, name) => + { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Relu", name, + null, + features); + + return results[0]; + } + + throw new NotImplementedException(""); + }; + } +} diff --git a/src/TensorFlowNET.Core/Keras/Activations/Activations.cs b/src/TensorFlowNET.Core/Keras/Activations/Activations.cs new file mode 100644 index 00000000..ad4d8d59 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Activations/Activations.cs @@ -0,0 +1,9 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras +{ + public delegate Tensor Activation(Tensor features, string name = null); +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs index 5e38da99..0dd4355f 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.ArgsDefinition public Layer[] InboundLayers { get; set; } public int[] NodeIndices { get; set; } public int[] TensorIndices { get; set; } - public Tensor[] InputTensors { get; set; } - public Tensor[] Outputs { get; set; } + public Tensor InputTensors { get; set; } + public Tensor Outputs { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs new file mode 100644 index 00000000..14f3f79a --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Layers; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + protected List _layers = new List(); + + protected Layer Dense(int units, + Activation activation = null, + TensorShape input_shape = null) + { + var layer = new Dense(new DenseArgs + { + Units = units, + Activation = activation ?? tf.keras.activations.Linear, + InputShape = input_shape + }); + + _layers.Add(layer); + return layer; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index fd83ae7e..8964bb0a 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -18,11 +18,9 @@ using System; using System.Collections.Generic; using System.Linq; using System.Threading; -using Tensorflow.Contexts; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Utils; -using Tensorflow.Operations.Activation; using Tensorflow.Train; using static Tensorflow.Binding; @@ -34,7 +32,7 @@ namespace Tensorflow.Keras.Engine /// as convolution, batch norm, etc. These operations require managing weights, /// losses, updates, and inter-layer connectivity. /// - public abstract class Layer : AutoTrackable + public abstract partial class Layer : AutoTrackable { /// /// Arguments initialize layer. @@ -60,8 +58,19 @@ namespace Tensorflow.Keras.Engine protected InputSpec inputSpec; public bool SupportsMasking { get; set; } protected List trainableWeights; - public List TrainableVariables => trainableWeights; + public List trainable_variables + { + get + { + if(trainableWeights.Count == 0) + _layers.ForEach(x => trainableWeights.AddRange(x.trainableWeights)); + + return trainableWeights; + } + } + protected List nonTrainableWeights; + public List non_trainable_variables => nonTrainableWeights; string name; public string Name => name; @@ -112,20 +121,20 @@ namespace Tensorflow.Keras.Engine /// /// /// - public Tensor[] Apply(Tensor[] inputs, bool is_training = false) + public Tensor Apply(Tensor inputs, bool is_training = false) { - var input = inputs[0]; - Tensor[] outputs = null; + Tensor outputs = null; callContext = callContext ?? new ThreadLocal() { Value = new CallContext() }; + var eager = tf.executing_eagerly(); using var ctxManager = CallContext.enter(); string nameScope = ""; - if (tf.executing_eagerly()) + if (eager) { nameScope = name; } @@ -134,7 +143,7 @@ namespace Tensorflow.Keras.Engine throw new NotImplementedException(""); } - using var graph = tf.keras.backend.get_graph().as_default(); + // using var graph = tf.keras.backend.get_graph().as_default(); tf_with(ops.name_scope(nameScope), scope => { @@ -143,74 +152,36 @@ namespace Tensorflow.Keras.Engine outputs = call(inputs, is_training: is_training); - (input, outputs) = _set_connectivity_metadata_(input, outputs); - _handle_activity_regularization(inputs[0], outputs); - _set_mask_metadata(inputs[0], outputs, null); + outputs = _set_connectivity_metadata_(inputs, outputs); + _handle_activity_regularization(inputs, outputs); + _set_mask_metadata(inputs, outputs, null); }); return outputs; } - [Obsolete("User Apply()")] - public Tensor[] __call__(Tensor[] inputs, - Tensor training = null, - Tensor state = null, - VariableScope scope = null) + private Tensor _set_connectivity_metadata_(Tensor inputs, Tensor outputs) { - var input_list = inputs; - var input = inputs[0]; - Tensor[] outputs = null; - - // We will attempt to build a TF graph if & only if all inputs are symbolic. - // This is always the case in graph mode. It can also be the case in eager - // mode when all inputs can be traced back to `keras.Input()` (when building - // models using the functional API). - bool build_graph = tf_utils.are_all_symbolic_tensors(input_list); - - if (build_graph) - { - // Only create Keras history if at least one tensor originates from a - // `keras.Input`. Otherwise this Layer may be being used outside the Keras - // framework. - // base_layer_utils.create_keras_history(inputs) - } - - // with base_layer_utils.call_context(self): - - // Handle Keras mask propagation from previous layer to current layer. - // with base_layer_utils.call_context(self): - // Check input assumptions set after layer building, e.g. input shape. - if (build_graph) + /*var returnOutputs = new List(); + foreach(var x in outputs) { - // Symbolic execution on symbolic tensors. We will attempt to build - // the corresponding TF subgraph inside `backend.get_graph()` - var graph = tf.keras.backend.get_graph().as_default(); - tf_with(ops.name_scope(_name_scope()), delegate + if (inputs.Contains(x)) { - // Build layer if applicable (if the `build` method has been - // overridden). - MaybeBuild(inputs); - - outputs = call(inputs, - // training: training, - state: state); - (input, outputs) = _set_connectivity_metadata_(input, outputs); - _handle_activity_regularization(inputs[0], outputs); - _set_mask_metadata(inputs[0], outputs, null); - }); - } + } + returnOutputs.Add(x); + }*/ - return outputs; - } + new Node(this, new NodeArgs + { + Outputs = outputs + }); - private (Tensor, Tensor[]) _set_connectivity_metadata_(Tensor inputs, Tensor[] outputs) - { //_add_inbound_node(input_tensors: inputs, output_tensors: outputs); - return (inputs, outputs); + return outputs; } - private void _handle_activity_regularization(Tensor inputs, Tensor[] outputs) + private void _handle_activity_regularization(Tensor inputs, Tensor outputs) { //if(_activity_regularizer != null) { @@ -218,7 +189,7 @@ namespace Tensorflow.Keras.Engine } } - private void _set_mask_metadata(Tensor inputs, Tensor[] outputs, Tensor previous_mask) + private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask) { } @@ -228,7 +199,7 @@ namespace Tensorflow.Keras.Engine return null; } - protected virtual Tensor[] call(Tensor[] inputs, bool is_training = false, Tensor state = null) + protected virtual Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) { throw new NotImplementedException(""); } @@ -238,15 +209,15 @@ namespace Tensorflow.Keras.Engine return Name; } - protected void MaybeBuild(Tensor[] inputs) + protected void MaybeBuild(Tensor inputs) { // Check input assumptions set before layer building, e.g. input rank. if (built) return; if (DType == TF_DataType.DtInvalid) - args.DType = inputs[0].dtype; + args.DType = inputs.dtype; - var input_shapes = inputs[0].TensorShape; + var input_shapes = inputs.TensorShape; build(input_shapes); built = true; } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Node.cs b/src/TensorFlowNET.Core/Keras/Engine/Node.cs index 5ada8791..ee734588 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Node.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Node.cs @@ -35,13 +35,13 @@ namespace Tensorflow.Keras.Engine public int[] node_indices; public int[] tensor_indices; - public Tensor[] input_tensors; - public Tensor[] Outputs => args.Outputs; + public Tensor input_tensors; + public Tensor Outputs => args.Outputs; public TensorShape[] input_shapes; public TensorShape[] output_shapes; List kerasInputs; - public Node(InputLayer layer, NodeArgs args) + public Node(Layer layer, NodeArgs args) { this.args = args; diff --git a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs index 3883b2c5..49e605c1 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs @@ -25,9 +25,7 @@ namespace Tensorflow.Keras.Engine #pragma warning disable CS0649 // Field 'Sequential._is_graph_network' is never assigned to, and will always have its default value false bool _is_graph_network; #pragma warning restore CS0649 // Field 'Sequential._is_graph_network' is never assigned to, and will always have its default value false -#pragma warning disable CS0169 // The field 'Sequential.outputs' is never used - Tensor[] outputs; -#pragma warning restore CS0169 // The field 'Sequential.outputs' is never used + Tensor outputs; bool computeOutputAndMaskJointly; bool autoTrackSubLayers; @@ -51,6 +49,11 @@ namespace Tensorflow.Keras.Engine } + public void add(Tensor layer) + { + + } + /// /// Adds a layer instance on top of the layer stack. /// @@ -71,7 +74,7 @@ namespace Tensorflow.Keras.Engine { // Instantiate an input layer. var x = tf.keras.Input( - batch_shape: layer.BatchInputShape, + shape: layer.BatchInputShape, dtype: layer.DType, name: layer.Name + "_input"); @@ -86,7 +89,7 @@ namespace Tensorflow.Keras.Engine if (set_inputs) { // If an input layer (placeholder) is available. - // outputs = layer.inbound_nodes; + outputs = layer.InboundNodes[^1].Outputs; } } diff --git a/src/TensorFlowNET.Core/Keras/KerasApi.cs b/src/TensorFlowNET.Core/Keras/KerasApi.cs index 982ec023..ef18d133 100644 --- a/src/TensorFlowNET.Core/Keras/KerasApi.cs +++ b/src/TensorFlowNET.Core/Keras/KerasApi.cs @@ -14,15 +14,35 @@ namespace Tensorflow { public KerasDataset datasets { get; } = new KerasDataset(); public Initializers initializers { get; } = new Initializers(); - public Layers layers { get; } = new Layers(); + public LayersApi layers { get; } = new LayersApi(); public Activations activations { get; } = new Activations(); public BackendImpl backend { get; } = new BackendImpl(); + public Models models { get; } = new Models(); + public Sequential Sequential() => new Sequential(); - public Tensor[] Input(int[] batch_shape = null, + /// + /// Instantiate a Keras tensor. + /// + /// + /// + /// + /// + /// + /// A boolean specifying whether the placeholder to be created is sparse. + /// + /// + /// A boolean specifying whether the placeholder to be created is ragged. + /// + /// + /// Optional existing tensor to wrap into the `Input` layer. + /// If set, the layer will not create a placeholder tensor. + /// + /// + public Tensor Input(TensorShape shape = null, int batch_size = -1, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, @@ -33,7 +53,7 @@ namespace Tensorflow var args = new InputLayerArgs { Name = name, - BatchInputShape = batch_shape, + InputShape = shape, BatchSize = batch_size, DType = dtype, Sparse = sparse, @@ -43,7 +63,7 @@ namespace Tensorflow var layer = new InputLayer(args); - return layer.InboundNodes[0].Outputs; + return layer.InboundNodes[0].Outputs[0]; } public static Embedding Embedding(int input_dim, @@ -55,7 +75,7 @@ namespace Tensorflow embeddings_initializer, mask_zero); - public class Layers + public class LayersApi { public Layer Dense(int units, Activation activation = null, diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 23992b56..c8298234 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -143,15 +143,15 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor[] call(Tensor[] inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) { Tensor outputs = null; if (fused) { Tensor training = tf.convert_to_tensor(is_training); - outputs = _fused_batch_norm(inputs[0], training: training); - return new[] { outputs, outputs }; + outputs = _fused_batch_norm(inputs, training: training); + return outputs; } throw new NotImplementedException("BatchNormalization call"); diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index a1d48be1..fa3b7505 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -108,9 +108,9 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor[] call(Tensor[] inputs, bool training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null) { - var outputs = _convolution_op.__call__(inputs[0], kernel); + var outputs = _convolution_op.__call__(inputs, kernel); if (use_bias) { if (data_format == "channels_first") @@ -126,7 +126,7 @@ namespace Tensorflow.Keras.Layers if (activation != null) outputs = activation.Activate(outputs); - return new[] { outputs, outputs }; + return outputs; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index c6485427..b6258aea 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -65,17 +65,17 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor[] call(Tensor[] inputs, bool training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool training = false, Tensor state = null) { Tensor outputs = null; - var rank = inputs[0].rank; + var rank = inputs.rank; if(rank > 2) { throw new NotImplementedException("call rank > 2"); } else { - outputs = gen_math_ops.mat_mul(inputs[0], kernel.AsTensor()); + outputs = gen_math_ops.mat_mul(inputs, kernel.AsTensor()); } if (args.UseBias) @@ -83,7 +83,7 @@ namespace Tensorflow.Keras.Layers if (args.Activation != null) outputs = activation(outputs); - return new[] { outputs }; + return outputs; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index 5080c425..cc38f553 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -57,14 +57,14 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override Tensor[] call(Tensor[] inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) { - var dtype = inputs[0].dtype; + var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) - inputs[0] = math_ops.cast(inputs[0], tf.int32); + inputs = math_ops.cast(inputs, tf.int32); - var @out = embedding_ops.embedding_lookup(embeddings, inputs[0]); - return new[] { @out, @out }; + var outputs = embedding_ops.embedding_lookup(embeddings, inputs[0]); + return outputs; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs index 02473904..055ed373 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs @@ -84,8 +84,8 @@ namespace Tensorflow.Keras.Layers // input_tensor._keras_mask = None new Node(this, new NodeArgs { - InputTensors = new Tensor[] { args.InputTensor }, - Outputs = new Tensor[] { args.InputTensor } + InputTensors = args.InputTensor, + Outputs = args.InputTensor }); typeSpec = new TensorSpec(args.InputTensor.TensorShape, diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index 6ee054bf..26f30885 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -45,7 +45,7 @@ namespace Tensorflow.Keras.Layers this.input_spec = new InputSpec(ndim: 4); } - protected override Tensor[] call(Tensor[] inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) { int[] pool_shape; if (data_format == "channels_last") @@ -60,13 +60,13 @@ namespace Tensorflow.Keras.Layers } var outputs = pool_function.Apply( - inputs[0], + inputs, ksize: pool_shape, strides: strides, padding: padding.ToUpper(), data_format: conv_utils.convert_data_format(data_format, 4)); - return new[] { outputs, outputs }; + return outputs; } } } diff --git a/src/TensorFlowNET.Core/Keras/Models.cs b/src/TensorFlowNET.Core/Keras/Models.cs new file mode 100644 index 00000000..545c6143 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Models.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras +{ + public class Models + { + public Sequential Sequential() + => new Sequential(); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs index 6d29f95d..6b926622 100644 --- a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs +++ b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs @@ -39,7 +39,7 @@ namespace Tensorflow.Keras.Optimizers public void apply_gradients((Tensor, ResourceVariable) grads_and_vars, string name = null, bool experimental_aggregate_gradients = true) - => apply_gradients(new (Tensor, ResourceVariable)[] { grads_and_vars }, + => apply_gradients(grads_and_vars, name: name, experimental_aggregate_gradients: experimental_aggregate_gradients); @@ -77,7 +77,7 @@ namespace Tensorflow.Keras.Optimizers _resource_apply_dense(var, grad, apply_state); } - protected virtual Operation _resource_apply_dense(ResourceVariable var, + protected virtual Operation _resource_apply_dense(IVariableV1 var, EagerTensor grad, Dictionary> _apply_state) { @@ -107,7 +107,7 @@ namespace Tensorflow.Keras.Optimizers return grads_and_vars.Select(x => x.Item1).ToArray(); } - Dictionary> _prepare(ResourceVariable[] var_list) + Dictionary> _prepare(IVariableV1[] var_list) { var _apply_state = new Dictionary>(); var keys = var_list.Select(x => new DeviceDType @@ -151,7 +151,7 @@ namespace Tensorflow.Keras.Optimizers return math_ops.cast(value, dtype); } - void _create_all_weights(ResourceVariable[] var_list) + void _create_all_weights(IVariableV1[] var_list) { if(_iterations == null) { @@ -190,7 +190,7 @@ namespace Tensorflow.Keras.Optimizers _hypers_created = true; } - void _create_slots(ResourceVariable[] var_list) + void _create_slots(IVariableV1[] var_list) { if(_momentum) { diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs b/src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs index 952f51cd..afedb391 100644 --- a/src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs +++ b/src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs @@ -40,7 +40,7 @@ namespace Tensorflow.Keras.Optimizers _get_hyper("momentum", device_dtype.DType)); } - protected override Operation _resource_apply_dense(ResourceVariable var, EagerTensor grad, Dictionary> _apply_state) + protected override Operation _resource_apply_dense(IVariableV1 var, EagerTensor grad, Dictionary> _apply_state) { if (_momentum) { diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 48003b94..ae3157b0 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -88,9 +88,8 @@ namespace Tensorflow.Layers { _current_scope = scope2; // Actually call layer - outputs = base.__call__(new Tensor[] { inputs }, - training: training, - state: state); + /*outputs = base.Apply(new Tensor[] { inputs }, + is_training: training);*/ }); diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs index a0fbc007..35d1a026 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs @@ -74,7 +74,7 @@ namespace Tensorflow /// /// /// - protected override Tensor[] call(Tensor[] inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) { var one = constant_op.constant(1, dtype: dtypes.int32); // Parameters of gates are concatenated into one multiply for efficiency. @@ -87,7 +87,7 @@ namespace Tensorflow // array_ops.split(value: state, num_or_size_splits: 2, axis: one); throw new NotImplementedException("BasicLstmCell call"); } - var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs[0], h }, 1), _kernel as RefVariable); + var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel as RefVariable); gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); // i = input_gate, j = new_input, f = forget_gate, o = output_gate @@ -105,9 +105,9 @@ namespace Tensorflow if (_state_is_tuple) - return new[] { new_c, new_h }; + return new_c; else - return new[] { array_ops.concat(new[] { new_c, new_h }, 1) }; + return array_ops.concat(new[] { new_c, new_h }, 1); } public override object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs index de8e7b95..55589e64 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs @@ -67,14 +67,14 @@ namespace Tensorflow built = true; } - protected override Tensor[] call(Tensor[] inputs, bool is_training = false, Tensor state = null) + protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) { // Most basic RNN: output = new_state = act(W * input + U * state + B). - var concat = array_ops.concat(new[] { inputs[0], state }, 1); + var concat = array_ops.concat(new[] { inputs, state }, 1); var gate_inputs = math_ops.matmul(concat, _kernel as RefVariable); gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); var output = _activation(gate_inputs, null); - return new[] { output, output }; + return output; } } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index ccc83864..31b06a32 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -108,6 +108,17 @@ namespace Tensorflow.Operations string data_format = null, string name = null) { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "BiasAdd", name, + null, + value, bias, + "data_format", data_format); + + return results[0]; + } + if (data_format == null) data_format = "NHWC"; @@ -125,6 +136,17 @@ namespace Tensorflow.Operations string data_format = "NHWC", string name = null) { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "BiasAddGrad", name, + null, + out_backprop, + "data_format", data_format); + + return results[0]; + } + if (data_format == null) data_format = "NHWC"; @@ -460,6 +482,16 @@ namespace Tensorflow.Operations /// public static (Tensor loss, Tensor backprop) sparse_softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = "SparseSoftmaxCrossEntropyWithLogits") { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "SparseSoftmaxCrossEntropyWithLogits", name, + null, + features, labels); + + return (results[0], results[1]); + } + var op = tf.OpDefLib._apply_op_helper("SparseSoftmaxCrossEntropyWithLogits", name: name, args: new { features, labels }); int _idx = 0; var loss = op.outputs[_idx++]; @@ -475,7 +507,7 @@ namespace Tensorflow.Operations /// A `Tensor`. Has the same type as `features`. public static Tensor relu(Tensor features, string name = null) { - if (tf.Context.executing_eagerly()) + if (tf.executing_eagerly()) { var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Relu", name, diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index cd38d4f8..8ee4d91d 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -31,8 +31,47 @@ namespace Tensorflow public static Tensor placeholder_with_default(T input, int[] shape, string name = null) => gen_array_ops.placeholder_with_default(input, shape, name); + /// + /// An identity op that triggers an error if a gradient is requested. + /// + /// + /// any tensor. + /// + /// + /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PreventGradient'. + /// + /// + /// Will be printed in the error when anyone tries to differentiate + /// this operation. + /// + /// + /// the same input tensor. + /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. + /// + /// + /// When executed in a graph, this op outputs its input tensor as-is. + /// + /// When building ops to compute gradients, the TensorFlow gradient system + /// will return an error when trying to lookup the gradient of this op, + /// because no gradient must ever be registered for this function. This + /// op exists to prevent subtle bugs from silently returning unimplemented + /// gradients in some corner cases. + /// public static Tensor prevent_gradient(Tensor input, string message = "", string name = null) - => gen_array_ops.prevent_gradient(input, message: message, name: name); + { + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "PreventGradient", name, + null, + input, + "message", message); + return results[0]; + } + + var op = tf.OpDefLib._apply_op_helper("PreventGradient", name: name, args: new { input, message }); + return op.output; + } internal static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 21c48c4a..4ce25b3c 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -186,38 +186,6 @@ namespace Tensorflow return _op.output; } - /// - /// An identity op that triggers an error if a gradient is requested. - /// - /// - /// any tensor. - /// - /// - /// If specified, the created operation in the graph will be this one, otherwise it will be named 'PreventGradient'. - /// - /// - /// Will be printed in the error when anyone tries to differentiate - /// this operation. - /// - /// - /// the same input tensor. - /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. - /// - /// - /// When executed in a graph, this op outputs its input tensor as-is. - /// - /// When building ops to compute gradients, the TensorFlow gradient system - /// will return an error when trying to lookup the gradient of this op, - /// because no gradient must ever be registered for this function. This - /// op exists to prevent subtle bugs from silently returning unimplemented - /// gradients in some corner cases. - /// - public static Tensor prevent_gradient(Tensor input, string message = "", string name = null) - { - var op = tf.OpDefLib._apply_op_helper("PreventGradient", name: name, args: new { input, message }); - return op.output; - } - /// /// Return a tensor with the same shape and contents as the input tensor or value. /// diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 2dd55e02..c8e24528 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -45,6 +45,7 @@ namespace Tensorflow public Operation Initializer => initializer_op; public Operation Op => handle.op; public Graph Graph => handle.graph; + public string Device => ""; public BaseResourceVariable() { @@ -148,6 +149,6 @@ namespace Tensorflow { } - public Tensor AsTensor() => _graph_element; + public Tensor AsTensor() => read_value(); } } diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index 68a1b78a..6295a1cd 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -33,6 +33,7 @@ namespace Tensorflow { public string Name { get; } public Tensor Handle { get; } + public string Device { get; } public Operation Initializer { get; } public Operation Op { get; } public Tensor GraphElement { get; } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 47392cdc..3fccc04e 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -48,6 +48,7 @@ namespace Tensorflow public TF_DataType dtype => _variable.dtype; public TensorShape shape => tensor_util.to_shape(_variable.shape); + public string Device => ""; public string Name => _variable.name; diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index ccf4033a..368c66ba 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -51,7 +51,6 @@ namespace Tensorflow { Status = new Status(); Context = new Context(new ContextOptions(), Status); - enable_eager_execution(); OpDefLib = new OpDefLibrary(); ConstructThreadingObjects(); InitGradientEnvironment(); diff --git a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs index 6d0ae4bc..88d5de78 100644 --- a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs +++ b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs @@ -16,6 +16,13 @@ namespace TensorFlowNET.UnitTest.Keras [TestClass, Ignore] public class LayersTest : GraphModeTestBase { + [TestMethod] + public void Sequential() + { + var model = tf.keras.models.Sequential(); + model.add(tf.keras.Input(shape: 16)); + } + [TestMethod] public void Embedding() {