diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 9b42eaaa..74432b2b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -139,14 +139,14 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { Tensor outputs = null; if (fused) { outputs = _fused_batch_norm(inputs, training: training); - return (outputs, outputs); + return new[] { outputs, outputs }; } throw new NotImplementedException("BatchNormalization call"); diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index ad233d6b..7f763fb8 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { var outputs = _convolution_op.__call__(inputs, kernel); if (use_bias) @@ -126,7 +126,7 @@ namespace Tensorflow.Keras.Layers if (activation != null) outputs = activation.Activate(outputs); - return (outputs, outputs); + return new[] { outputs, outputs }; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index 74778873..bfd6f2a5 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { Tensor outputs = null; var rank = inputs.rank; @@ -90,7 +90,7 @@ namespace Tensorflow.Keras.Layers if (activation != null) outputs = activation.Activate(outputs); - return (outputs, outputs); + return new[] { outputs, outputs }; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs index 95544d36..89ad4a63 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Embedding.cs @@ -50,14 +50,14 @@ namespace Tensorflow.Keras.Layers built = true; } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { var dtype = inputs.dtype; if (dtype != tf.int32 && dtype != tf.int64) inputs = math_ops.cast(inputs, tf.int32); var @out = embedding_ops.embedding_lookup(embeddings, inputs); - return (@out, @out); + return new[] { @out, @out }; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index d7d7e31a..3ab37a0b 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -103,14 +103,14 @@ namespace Tensorflow.Keras.Layers _inbound_nodes = new List(); } - public (Tensor, Tensor) __call__(Tensor[] inputs, + public Tensor[] __call__(Tensor[] inputs, Tensor training = null, Tensor state = null, VariableScope scope = null) { var input_list = inputs; var input = inputs[0]; - Tensor outputs = null; + Tensor[] outputs = null; // We will attempt to build a TF graph if & only if all inputs are symbolic. // This is always the case in graph mode. It can also be the case in eager @@ -142,25 +142,26 @@ namespace Tensorflow.Keras.Layers // overridden). _maybe_build(inputs[0]); - (input, outputs) = call(inputs[0], + outputs = call(inputs[0], training: training, state: state); + (input, outputs) = _set_connectivity_metadata_(input, outputs); _handle_activity_regularization(inputs[0], outputs); _set_mask_metadata(inputs[0], outputs, null); }); } - return (input, outputs); + return outputs; } - private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs) + private (Tensor, Tensor[]) _set_connectivity_metadata_(Tensor inputs, Tensor[] outputs) { //_add_inbound_node(input_tensors: inputs, output_tensors: outputs); return (inputs, outputs); } - private void _handle_activity_regularization(Tensor inputs, Tensor outputs) + private void _handle_activity_regularization(Tensor inputs, Tensor[] outputs) { //if(_activity_regularizer != null) { @@ -168,7 +169,7 @@ namespace Tensorflow.Keras.Layers } } - private void _set_mask_metadata(Tensor inputs, Tensor outputs, Tensor previous_mask) + private void _set_mask_metadata(Tensor inputs, Tensor[] outputs, Tensor previous_mask) { } @@ -178,9 +179,9 @@ namespace Tensorflow.Keras.Layers return null; } - protected virtual (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected virtual Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { - return (inputs, inputs); + throw new NotImplementedException(""); } protected virtual string _name_scope() diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index 81d57abe..ccb1cd6f 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -43,7 +43,7 @@ namespace Tensorflow.Keras.Layers this.input_spec = new InputSpec(ndim: 4); } - protected override (Tensor, Tensor) call(Tensor inputs, Tensor training = null, Tensor state = null) + protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) { int[] pool_shape; if (data_format == "channels_last") @@ -64,7 +64,7 @@ namespace Tensorflow.Keras.Layers padding: padding.ToUpper(), data_format: conv_utils.convert_data_format(data_format, 4)); - return (outputs, outputs); + return new[] { outputs, outputs }; } } }