From 70f873eccef99e4ca6af39a8ac798cc36292ace2 Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Mon, 10 Jul 2023 15:02:39 -0500 Subject: [PATCH] Initially adding KerasTensor. #1142 --- src/TensorFlowNET.Core/GlobalUsing.cs | 4 +- .../Keras/Layers/ILayersApi.cs | 3 +- src/TensorFlowNET.Core/Tensors/KerasTensor.cs | 40 +++++++++++++++++++ src/TensorFlowNET.Keras/BackendImpl.cs | 2 +- src/TensorFlowNET.Keras/GlobalUsing.cs | 3 +- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 2 +- 6 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 src/TensorFlowNET.Core/Tensors/KerasTensor.cs diff --git a/src/TensorFlowNET.Core/GlobalUsing.cs b/src/TensorFlowNET.Core/GlobalUsing.cs index 2fd5b437..209bc291 100644 --- a/src/TensorFlowNET.Core/GlobalUsing.cs +++ b/src/TensorFlowNET.Core/GlobalUsing.cs @@ -3,4 +3,6 @@ global using System.Collections.Generic; global using System.Text; global using System.Collections; global using System.Data; -global using System.Linq; \ No newline at end of file +global using System.Linq; +global using Tensorflow.Keras.Engine; +global using Tensorflow.Framework.Models; \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs index 9bc99701..b48cd553 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs @@ -1,5 +1,6 @@ using System; using Tensorflow.Framework.Models; +using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers.Rnn; using Tensorflow.NumPy; using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; @@ -135,7 +136,7 @@ namespace Tensorflow.Keras.Layers public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); - public Tensors Input(Shape shape = null, + public KerasTensor Input(Shape shape = null, int batch_size = -1, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, diff --git a/src/TensorFlowNET.Core/Tensors/KerasTensor.cs b/src/TensorFlowNET.Core/Tensors/KerasTensor.cs new file mode 100644 index 00000000..1034dcc8 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/KerasTensor.cs @@ -0,0 +1,40 @@ +namespace Tensorflow.Keras.Engine; + +/// +/// A representation of a Keras in/output during Functional API construction. +/// +public class KerasTensor +{ + private Tensor _tensor; + public void SetTensor(Tensors tensor) + => _tensor = tensor; + + private TensorSpec _type_spec; + private string _name; + + public KerasTensor(TensorSpec type_spec, string name = null) + { + _type_spec = type_spec; + _name = name; + } + + public static KerasTensor from_tensor(Tensor tensor) + { + var type_spec = tensor.ToTensorSpec(); + var kt = new KerasTensor(type_spec, name: tensor.name); + kt.SetTensor(tensor); + return kt; + } + + public static implicit operator Tensors(KerasTensor kt) + => kt._tensor; + + public static implicit operator Tensor(KerasTensor kt) + => kt._tensor; + + public static implicit operator KerasTensor(Tensor tensor) + => from_tensor(tensor); + + public static implicit operator KerasTensor(Tensors tensors) + => from_tensor(tensors.First()); +} diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 364800ae..574cf599 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -76,7 +76,7 @@ namespace Tensorflow.Keras _GRAPH_VARIABLES[graph.graph_key] = v; } - public Tensor placeholder(Shape shape = null, + public KerasTensor placeholder(Shape shape = null, int ndim = -1, TF_DataType dtype = TF_DataType.DtInvalid, bool sparse = false, diff --git a/src/TensorFlowNET.Keras/GlobalUsing.cs b/src/TensorFlowNET.Keras/GlobalUsing.cs index bc0798ed..85cd9194 100644 --- a/src/TensorFlowNET.Keras/GlobalUsing.cs +++ b/src/TensorFlowNET.Keras/GlobalUsing.cs @@ -4,4 +4,5 @@ global using System.Text; global using System.Linq; global using static Tensorflow.Binding; global using static Tensorflow.KerasApi; -global using Tensorflow.NumPy; \ No newline at end of file +global using Tensorflow.NumPy; +global using Tensorflow.Keras.Engine; \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 213b53a8..5968461d 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -466,7 +466,7 @@ namespace Tensorflow.Keras.Layers /// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide. /// /// A tensor. - public Tensors Input(Shape shape = null, + public KerasTensor Input(Shape shape = null, int batch_size = -1, string name = null, TF_DataType dtype = TF_DataType.DtInvalid,