diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs index dcbfa1e6..ceb3afa4 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; namespace Tensorflow.Keras.Engine { @@ -11,5 +12,8 @@ namespace Tensorflow.Keras.Engine { _layers.AddRange(layers); } + + public virtual TensorShape ComputeOutputShape(TensorShape input_shape) + => throw new NotImplementedException(""); } } diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs index b358c719..28c7be3e 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs @@ -26,11 +26,11 @@ namespace Tensorflow.Keras.Layers var result = array_ops.reshape(inputs, shape.ToArray()); if (!tf.Context.executing_eagerly()) - result.set_shape(compute_output_shape(inputs.shape)); + result.set_shape(ComputeOutputShape(inputs.shape)); return result; } - TensorShape compute_output_shape(TensorShape input_shape) + public override TensorShape ComputeOutputShape(TensorShape input_shape) { if (input_shape.dims[0] == -1) {