diff --git a/src/TensorFlowNET.Core/GlobalUsing.cs b/src/TensorFlowNET.Core/GlobalUsing.cs
index 2fd5b437..209bc291 100644
--- a/src/TensorFlowNET.Core/GlobalUsing.cs
+++ b/src/TensorFlowNET.Core/GlobalUsing.cs
@@ -3,4 +3,6 @@ global using System.Collections.Generic;
global using System.Text;
global using System.Collections;
global using System.Data;
-global using System.Linq;
\ No newline at end of file
+global using System.Linq;
+global using Tensorflow.Keras.Engine;
+global using Tensorflow.Framework.Models;
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
index 9bc99701..b48cd553 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -1,5 +1,6 @@
using System;
using Tensorflow.Framework.Models;
+using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.NumPy;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
@@ -135,7 +136,7 @@ namespace Tensorflow.Keras.Layers
public ILayer GlobalMaxPooling1D(string data_format = "channels_last");
public ILayer GlobalMaxPooling2D(string data_format = "channels_last");
- public Tensors Input(Shape shape = null,
+ public KerasTensor Input(Shape shape = null,
int batch_size = -1,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
diff --git a/src/TensorFlowNET.Core/Tensors/KerasTensor.cs b/src/TensorFlowNET.Core/Tensors/KerasTensor.cs
new file mode 100644
index 00000000..1034dcc8
--- /dev/null
+++ b/src/TensorFlowNET.Core/Tensors/KerasTensor.cs
@@ -0,0 +1,40 @@
+namespace Tensorflow.Keras.Engine;
+
+///
+/// A representation of a Keras in/output during Functional API construction.
+///
+public class KerasTensor
+{
+ private Tensor _tensor;
+ public void SetTensor(Tensors tensor)
+ => _tensor = tensor;
+
+ private TensorSpec _type_spec;
+ private string _name;
+
+ public KerasTensor(TensorSpec type_spec, string name = null)
+ {
+ _type_spec = type_spec;
+ _name = name;
+ }
+
+ public static KerasTensor from_tensor(Tensor tensor)
+ {
+ var type_spec = tensor.ToTensorSpec();
+ var kt = new KerasTensor(type_spec, name: tensor.name);
+ kt.SetTensor(tensor);
+ return kt;
+ }
+
+ public static implicit operator Tensors(KerasTensor kt)
+ => kt._tensor;
+
+ public static implicit operator Tensor(KerasTensor kt)
+ => kt._tensor;
+
+ public static implicit operator KerasTensor(Tensor tensor)
+ => from_tensor(tensor);
+
+ public static implicit operator KerasTensor(Tensors tensors)
+ => from_tensor(tensors.First());
+}
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index 364800ae..574cf599 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -76,7 +76,7 @@ namespace Tensorflow.Keras
_GRAPH_VARIABLES[graph.graph_key] = v;
}
- public Tensor placeholder(Shape shape = null,
+ public KerasTensor placeholder(Shape shape = null,
int ndim = -1,
TF_DataType dtype = TF_DataType.DtInvalid,
bool sparse = false,
diff --git a/src/TensorFlowNET.Keras/GlobalUsing.cs b/src/TensorFlowNET.Keras/GlobalUsing.cs
index bc0798ed..85cd9194 100644
--- a/src/TensorFlowNET.Keras/GlobalUsing.cs
+++ b/src/TensorFlowNET.Keras/GlobalUsing.cs
@@ -4,4 +4,5 @@ global using System.Text;
global using System.Linq;
global using static Tensorflow.Binding;
global using static Tensorflow.KerasApi;
-global using Tensorflow.NumPy;
\ No newline at end of file
+global using Tensorflow.NumPy;
+global using Tensorflow.Keras.Engine;
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 213b53a8..5968461d 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -466,7 +466,7 @@ namespace Tensorflow.Keras.Layers
/// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide.
///
/// A tensor.
- public Tensors Input(Shape shape = null,
+ public KerasTensor Input(Shape shape = null,
int batch_size = -1,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,