| @@ -3,4 +3,6 @@ global using System.Collections.Generic; | |||||
| global using System.Text; | global using System.Text; | ||||
| global using System.Collections; | global using System.Collections; | ||||
| global using System.Data; | global using System.Data; | ||||
| global using System.Linq; | |||||
| global using System.Linq; | |||||
| global using Tensorflow.Keras.Engine; | |||||
| global using Tensorflow.Framework.Models; | |||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Layers.Rnn; | using Tensorflow.Keras.Layers.Rnn; | ||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | ||||
| @@ -135,7 +136,7 @@ namespace Tensorflow.Keras.Layers | |||||
| public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); | public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); | ||||
| public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); | public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); | ||||
| public Tensors Input(Shape shape = null, | |||||
| public KerasTensor Input(Shape shape = null, | |||||
| int batch_size = -1, | int batch_size = -1, | ||||
| string name = null, | string name = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| @@ -0,0 +1,40 @@ | |||||
| namespace Tensorflow.Keras.Engine; | |||||
| /// <summary> | |||||
| /// A representation of a Keras in/output during Functional API construction. | |||||
| /// </summary> | |||||
| public class KerasTensor | |||||
| { | |||||
| private Tensor _tensor; | |||||
| public void SetTensor(Tensors tensor) | |||||
| => _tensor = tensor; | |||||
| private TensorSpec _type_spec; | |||||
| private string _name; | |||||
| public KerasTensor(TensorSpec type_spec, string name = null) | |||||
| { | |||||
| _type_spec = type_spec; | |||||
| _name = name; | |||||
| } | |||||
| public static KerasTensor from_tensor(Tensor tensor) | |||||
| { | |||||
| var type_spec = tensor.ToTensorSpec(); | |||||
| var kt = new KerasTensor(type_spec, name: tensor.name); | |||||
| kt.SetTensor(tensor); | |||||
| return kt; | |||||
| } | |||||
| public static implicit operator Tensors(KerasTensor kt) | |||||
| => kt._tensor; | |||||
| public static implicit operator Tensor(KerasTensor kt) | |||||
| => kt._tensor; | |||||
| public static implicit operator KerasTensor(Tensor tensor) | |||||
| => from_tensor(tensor); | |||||
| public static implicit operator KerasTensor(Tensors tensors) | |||||
| => from_tensor(tensors.First()); | |||||
| } | |||||
| @@ -76,7 +76,7 @@ namespace Tensorflow.Keras | |||||
| _GRAPH_VARIABLES[graph.graph_key] = v; | _GRAPH_VARIABLES[graph.graph_key] = v; | ||||
| } | } | ||||
| public Tensor placeholder(Shape shape = null, | |||||
| public KerasTensor placeholder(Shape shape = null, | |||||
| int ndim = -1, | int ndim = -1, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| bool sparse = false, | bool sparse = false, | ||||
| @@ -4,4 +4,5 @@ global using System.Text; | |||||
| global using System.Linq; | global using System.Linq; | ||||
| global using static Tensorflow.Binding; | global using static Tensorflow.Binding; | ||||
| global using static Tensorflow.KerasApi; | global using static Tensorflow.KerasApi; | ||||
| global using Tensorflow.NumPy; | |||||
| global using Tensorflow.NumPy; | |||||
| global using Tensorflow.Keras.Engine; | |||||
| @@ -466,7 +466,7 @@ namespace Tensorflow.Keras.Layers | |||||
| /// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide. | /// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide. | ||||
| /// </param> | /// </param> | ||||
| /// <returns>A tensor.</returns> | /// <returns>A tensor.</returns> | ||||
| public Tensors Input(Shape shape = null, | |||||
| public KerasTensor Input(Shape shape = null, | |||||
| int batch_size = -1, | int batch_size = -1, | ||||
| string name = null, | string name = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||