Browse Source

Resolve the comments and errors.

pull/989/head
Yaohui Liu 2 years ago
parent
commit
0060039a73
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
3 changed files with 38 additions and 61 deletions
  1. +17
    -11
      src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
  2. +1
    -4
      src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs
  3. +20
    -46
      src/TensorFlowNET.Keras/Utils/generic_utils.cs

+ 17
- 11
src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs View File

@@ -4,26 +4,26 @@ using System.IO;
using System.Linq; using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
using Tensorflow.Util;


namespace Tensorflow.Checkpoint namespace Tensorflow.Checkpoint
{ {
public class CheckpointReader : IDisposable
public class CheckpointReader : SafeTensorflowHandle
{ {
private IntPtr _reader;
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; } public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
public Dictionary<string, Shape> VariableToShapeMap { get; set; } public Dictionary<string, Shape> VariableToShapeMap { get; set; }


public CheckpointReader(string filename) public CheckpointReader(string filename)
{ {
Status status = new Status(); Status status = new Status();
_reader = c_api.TF_NewCheckpointReader(filename, status.Handle);
handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
status.Check(true); status.Check(true);
ReadAllShapeAndType(); ReadAllShapeAndType();
} }


public int HasTensor(string name) public int HasTensor(string name)
{ {
return c_api.TF_CheckpointReaderHasTensor(_reader, name);
return c_api.TF_CheckpointReaderHasTensor(handle, name);
} }


/// <summary> /// <summary>
@@ -33,17 +33,17 @@ namespace Tensorflow.Checkpoint
/// <returns></returns> /// <returns></returns>
public string GetVariable(int index) public string GetVariable(int index)
{ {
return c_api.TF_CheckpointReaderGetVariable(_reader, index);
return c_api.TF_CheckpointReaderGetVariable(handle, index);
} }


public int Size() public int Size()
{ {
return c_api.TF_CheckpointReaderSize(_reader);
return c_api.TF_CheckpointReaderSize(handle);
} }


public TF_DataType GetVariableDataType(string name) public TF_DataType GetVariableDataType(string name)
{ {
return c_api.TF_CheckpointReaderGetVariableDataType(_reader, name);
return c_api.TF_CheckpointReaderGetVariableDataType(handle, name);
} }


public Shape GetVariableShape(string name) public Shape GetVariableShape(string name)
@@ -52,20 +52,20 @@ namespace Tensorflow.Checkpoint
int num_dims = GetVariableNumDims(name); int num_dims = GetVariableNumDims(name);
long[] dims = new long[num_dims]; long[] dims = new long[num_dims];
Status status = new Status(); Status status = new Status();
c_api.TF_CheckpointReaderGetVariableShape(_reader, name, dims, num_dims, status.Handle);
c_api.TF_CheckpointReaderGetVariableShape(handle, name, dims, num_dims, status.Handle);
status.Check(true); status.Check(true);
return new Shape(dims); return new Shape(dims);
} }


public int GetVariableNumDims(string name) public int GetVariableNumDims(string name)
{ {
return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name);
return c_api.TF_CheckpointReaderGetVariableNumDims(handle, name);
} }


public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
{ {
Status status = new Status(); Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle);
var tensor = c_api.TF_CheckpointReaderGetTensor(handle, name, status.Handle);
status.Check(true); status.Check(true);
var shape = GetVariableShape(name); var shape = GetVariableShape(name);
if(dtype == TF_DataType.DtInvalid) if(dtype == TF_DataType.DtInvalid)
@@ -90,9 +90,15 @@ namespace Tensorflow.Checkpoint
} }
} }


protected override bool ReleaseHandle()
{
c_api.TF_DeleteCheckpointReader(handle);
return true;
}

public void Dispose() public void Dispose()
{ {
c_api.TF_DeleteCheckpointReader(_reader);
c_api.TF_DeleteCheckpointReader(handle);
} }
} }
} }

+ 1
- 4
src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs View File

@@ -1,8 +1,5 @@
using Newtonsoft.Json.Linq; using Newtonsoft.Json.Linq;
using Newtonsoft.Json; using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;


namespace Tensorflow.Keras.Common namespace Tensorflow.Keras.Common
{ {
@@ -19,7 +16,7 @@ namespace Tensorflow.Keras.Common


public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{ {
var token = JToken.FromObject(value);
var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value));
token.WriteTo(writer); token.WriteTo(writer);
} }




+ 20
- 46
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -22,6 +22,7 @@ using System.Collections.Generic;
using System.Data; using System.Data;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Reflection;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers; using Tensorflow.Keras.Layers;
@@ -58,59 +59,32 @@ namespace Tensorflow.Keras.Utils


public static Layer deserialize_keras_object(string class_name, JToken config) public static Layer deserialize_keras_object(string class_name, JToken config)
{ {
return class_name switch
{
"Sequential" => new Sequential(config.ToObject<SequentialArgs>()),
"InputLayer" => new InputLayer(config.ToObject<InputLayerArgs>()),
"Flatten" => new Flatten(config.ToObject<FlattenArgs>()),
"ELU" => new ELU(config.ToObject<ELUArgs>()),
"Dense" => new Dense(config.ToObject<DenseArgs>()),
"Softmax" => new Softmax(config.ToObject<SoftmaxArgs>()),
"Conv2D" => new Conv2D(config.ToObject<Conv2DArgs>()),
"BatchNormalization" => new BatchNormalization(config.ToObject<BatchNormalizationArgs>()),
"MaxPooling2D" => new MaxPooling2D(config.ToObject<MaxPooling2DArgs>()),
"Dropout" => new Dropout(config.ToObject<DropoutArgs>()),
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " +
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
};
var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args");
var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public)
.Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0);
var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType);
var args = deserializationGenericMethod.Invoke(config, null);
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null);
Debug.Assert(layer is Layer);
return layer as Layer;
} }


public static Layer deserialize_keras_object(string class_name, LayerArgs args) public static Layer deserialize_keras_object(string class_name, LayerArgs args)
{ {
return class_name switch
{
"Sequential" => new Sequential(args as SequentialArgs),
"InputLayer" => new InputLayer(args as InputLayerArgs),
"Flatten" => new Flatten(args as FlattenArgs),
"ELU" => new ELU(args as ELUArgs),
"Dense" => new Dense(args as DenseArgs),
"Softmax" => new Softmax(args as SoftmaxArgs),
"Conv2D" => new Conv2D(args as Conv2DArgs),
"BatchNormalization" => new BatchNormalization(args as BatchNormalizationArgs),
"MaxPooling2D" => new MaxPooling2D(args as MaxPooling2DArgs),
"Dropout" => new Dropout(args as DropoutArgs),
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " +
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
};
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null);
Debug.Assert(layer is Layer);
return layer as Layer;
} }


public static LayerArgs? deserialize_layer_args(string class_name, JToken config)
public static LayerArgs deserialize_layer_args(string class_name, JToken config)
{ {
return class_name switch
{
"Sequential" => config.ToObject<SequentialArgs>(),
"InputLayer" => config.ToObject<InputLayerArgs>(),
"Flatten" => config.ToObject<FlattenArgs>(),
"ELU" => config.ToObject<ELUArgs>(),
"Dense" => config.ToObject<DenseArgs>(),
"Softmax" => config.ToObject<SoftmaxArgs>(),
"Conv2D" => config.ToObject<Conv2DArgs>(),
"BatchNormalization" => config.ToObject<BatchNormalizationArgs>(),
"MaxPooling2D" => config.ToObject<MaxPooling2DArgs>(),
"Dropout" => config.ToObject<DropoutArgs>(),
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " +
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
};
var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args");
var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public)
.Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0);
var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType);
var args = deserializationGenericMethod.Invoke(config, null);
Debug.Assert(args is LayerArgs);
return args as LayerArgs;
} }


public static ModelConfig deserialize_model_config(JToken json) public static ModelConfig deserialize_model_config(JToken json)


Loading…
Cancel
Save