Browse Source

Initially adding KerasTensor. #1142

tags/v0.110.4-Transformer-Model
Haiping Chen 2 years ago
parent
commit
70f873ecce
6 changed files with 49 additions and 5 deletions
  1. +3
    -1
      src/TensorFlowNET.Core/GlobalUsing.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  3. +40
    -0
      src/TensorFlowNET.Core/Tensors/KerasTensor.cs
  4. +1
    -1
      src/TensorFlowNET.Keras/BackendImpl.cs
  5. +2
    -1
      src/TensorFlowNET.Keras/GlobalUsing.cs
  6. +1
    -1
      src/TensorFlowNET.Keras/Layers/LayersApi.cs

+ 3
- 1
src/TensorFlowNET.Core/GlobalUsing.cs View File

@@ -3,4 +3,6 @@ global using System.Collections.Generic;
global using System.Text; global using System.Text;
global using System.Collections; global using System.Collections;
global using System.Data; global using System.Data;
global using System.Linq;
global using System.Linq;
global using Tensorflow.Keras.Engine;
global using Tensorflow.Framework.Models;

+ 2
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using Tensorflow.Framework.Models; using Tensorflow.Framework.Models;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn; using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.NumPy; using Tensorflow.NumPy;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
@@ -135,7 +136,7 @@ namespace Tensorflow.Keras.Layers
public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); public ILayer GlobalMaxPooling1D(string data_format = "channels_last");
public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); public ILayer GlobalMaxPooling2D(string data_format = "channels_last");


public Tensors Input(Shape shape = null,
public KerasTensor Input(Shape shape = null,
int batch_size = -1, int batch_size = -1,
string name = null, string name = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,


+ 40
- 0
src/TensorFlowNET.Core/Tensors/KerasTensor.cs View File

@@ -0,0 +1,40 @@
namespace Tensorflow.Keras.Engine;

/// <summary>
/// A representation of a Keras in/output during Functional API construction.
/// </summary>
public class KerasTensor
{
private Tensor _tensor;
public void SetTensor(Tensors tensor)
=> _tensor = tensor;

private TensorSpec _type_spec;
private string _name;

public KerasTensor(TensorSpec type_spec, string name = null)
{
_type_spec = type_spec;
_name = name;
}

public static KerasTensor from_tensor(Tensor tensor)
{
var type_spec = tensor.ToTensorSpec();
var kt = new KerasTensor(type_spec, name: tensor.name);
kt.SetTensor(tensor);
return kt;
}

public static implicit operator Tensors(KerasTensor kt)
=> kt._tensor;

public static implicit operator Tensor(KerasTensor kt)
=> kt._tensor;

public static implicit operator KerasTensor(Tensor tensor)
=> from_tensor(tensor);

public static implicit operator KerasTensor(Tensors tensors)
=> from_tensor(tensors.First());
}

+ 1
- 1
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -76,7 +76,7 @@ namespace Tensorflow.Keras
_GRAPH_VARIABLES[graph.graph_key] = v; _GRAPH_VARIABLES[graph.graph_key] = v;
} }


public Tensor placeholder(Shape shape = null,
public KerasTensor placeholder(Shape shape = null,
int ndim = -1, int ndim = -1,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
bool sparse = false, bool sparse = false,


+ 2
- 1
src/TensorFlowNET.Keras/GlobalUsing.cs View File

@@ -4,4 +4,5 @@ global using System.Text;
global using System.Linq; global using System.Linq;
global using static Tensorflow.Binding; global using static Tensorflow.Binding;
global using static Tensorflow.KerasApi; global using static Tensorflow.KerasApi;
global using Tensorflow.NumPy;
global using Tensorflow.NumPy;
global using Tensorflow.Keras.Engine;

+ 1
- 1
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -466,7 +466,7 @@ namespace Tensorflow.Keras.Layers
/// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide. /// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide.
/// </param> /// </param>
/// <returns>A tensor.</returns> /// <returns>A tensor.</returns>
public Tensors Input(Shape shape = null,
public KerasTensor Input(Shape shape = null,
int batch_size = -1, int batch_size = -1,
string name = null, string name = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,


Loading…
Cancel
Save