Browse Source

Resolve the comments and errors.

pull/989/head
Yaohui Liu 2 years ago
parent
commit
0060039a73
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
3 changed files with 38 additions and 61 deletions
  1. +17
    -11
      src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
  2. +1
    -4
      src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs
  3. +20
    -46
      src/TensorFlowNET.Keras/Utils/generic_utils.cs

+ 17
- 11
src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs View File

@@ -4,26 +4,26 @@ using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow.Util;

namespace Tensorflow.Checkpoint
{
public class CheckpointReader : IDisposable
public class CheckpointReader : SafeTensorflowHandle
{
private IntPtr _reader;
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
public Dictionary<string, Shape> VariableToShapeMap { get; set; }

public CheckpointReader(string filename)
{
Status status = new Status();
_reader = c_api.TF_NewCheckpointReader(filename, status.Handle);
handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
status.Check(true);
ReadAllShapeAndType();
}

public int HasTensor(string name)
{
return c_api.TF_CheckpointReaderHasTensor(_reader, name);
return c_api.TF_CheckpointReaderHasTensor(handle, name);
}

/// <summary>
@@ -33,17 +33,17 @@ namespace Tensorflow.Checkpoint
/// <returns></returns>
public string GetVariable(int index)
{
return c_api.TF_CheckpointReaderGetVariable(_reader, index);
return c_api.TF_CheckpointReaderGetVariable(handle, index);
}

public int Size()
{
return c_api.TF_CheckpointReaderSize(_reader);
return c_api.TF_CheckpointReaderSize(handle);
}

public TF_DataType GetVariableDataType(string name)
{
return c_api.TF_CheckpointReaderGetVariableDataType(_reader, name);
return c_api.TF_CheckpointReaderGetVariableDataType(handle, name);
}

public Shape GetVariableShape(string name)
@@ -52,20 +52,20 @@ namespace Tensorflow.Checkpoint
int num_dims = GetVariableNumDims(name);
long[] dims = new long[num_dims];
Status status = new Status();
c_api.TF_CheckpointReaderGetVariableShape(_reader, name, dims, num_dims, status.Handle);
c_api.TF_CheckpointReaderGetVariableShape(handle, name, dims, num_dims, status.Handle);
status.Check(true);
return new Shape(dims);
}

public int GetVariableNumDims(string name)
{
return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name);
return c_api.TF_CheckpointReaderGetVariableNumDims(handle, name);
}

public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
{
Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle);
var tensor = c_api.TF_CheckpointReaderGetTensor(handle, name, status.Handle);
status.Check(true);
var shape = GetVariableShape(name);
if(dtype == TF_DataType.DtInvalid)
@@ -90,9 +90,15 @@ namespace Tensorflow.Checkpoint
}
}

protected override bool ReleaseHandle()
{
c_api.TF_DeleteCheckpointReader(handle);
return true;
}

public void Dispose()
{
c_api.TF_DeleteCheckpointReader(_reader);
c_api.TF_DeleteCheckpointReader(handle);
}
}
}

+ 1
- 4
src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs View File

@@ -1,8 +1,5 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
@@ -19,7 +16,7 @@ namespace Tensorflow.Keras.Common

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
var token = JToken.FromObject(value);
var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value));
token.WriteTo(writer);
}



+ 20
- 46
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -22,6 +22,7 @@ using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
@@ -58,59 +59,32 @@ namespace Tensorflow.Keras.Utils

public static Layer deserialize_keras_object(string class_name, JToken config)
{
return class_name switch
{
"Sequential" => new Sequential(config.ToObject<SequentialArgs>()),
"InputLayer" => new InputLayer(config.ToObject<InputLayerArgs>()),
"Flatten" => new Flatten(config.ToObject<FlattenArgs>()),
"ELU" => new ELU(config.ToObject<ELUArgs>()),
"Dense" => new Dense(config.ToObject<DenseArgs>()),
"Softmax" => new Softmax(config.ToObject<SoftmaxArgs>()),
"Conv2D" => new Conv2D(config.ToObject<Conv2DArgs>()),
"BatchNormalization" => new BatchNormalization(config.ToObject<BatchNormalizationArgs>()),
"MaxPooling2D" => new MaxPooling2D(config.ToObject<MaxPooling2DArgs>()),
"Dropout" => new Dropout(config.ToObject<DropoutArgs>()),
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " +
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
};
var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args");
var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public)
.Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0);
var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType);
var args = deserializationGenericMethod.Invoke(config, null);
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null);
Debug.Assert(layer is Layer);
return layer as Layer;
}

public static Layer deserialize_keras_object(string class_name, LayerArgs args)
{
return class_name switch
{
"Sequential" => new Sequential(args as SequentialArgs),
"InputLayer" => new InputLayer(args as InputLayerArgs),
"Flatten" => new Flatten(args as FlattenArgs),
"ELU" => new ELU(args as ELUArgs),
"Dense" => new Dense(args as DenseArgs),
"Softmax" => new Softmax(args as SoftmaxArgs),
"Conv2D" => new Conv2D(args as Conv2DArgs),
"BatchNormalization" => new BatchNormalization(args as BatchNormalizationArgs),
"MaxPooling2D" => new MaxPooling2D(args as MaxPooling2DArgs),
"Dropout" => new Dropout(args as DropoutArgs),
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " +
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
};
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null);
Debug.Assert(layer is Layer);
return layer as Layer;
}

public static LayerArgs? deserialize_layer_args(string class_name, JToken config)
public static LayerArgs deserialize_layer_args(string class_name, JToken config)
{
return class_name switch
{
"Sequential" => config.ToObject<SequentialArgs>(),
"InputLayer" => config.ToObject<InputLayerArgs>(),
"Flatten" => config.ToObject<FlattenArgs>(),
"ELU" => config.ToObject<ELUArgs>(),
"Dense" => config.ToObject<DenseArgs>(),
"Softmax" => config.ToObject<SoftmaxArgs>(),
"Conv2D" => config.ToObject<Conv2DArgs>(),
"BatchNormalization" => config.ToObject<BatchNormalizationArgs>(),
"MaxPooling2D" => config.ToObject<MaxPooling2DArgs>(),
"Dropout" => config.ToObject<DropoutArgs>(),
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " +
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
};
var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args");
var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public)
.Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0);
var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType);
var args = deserializationGenericMethod.Invoke(config, null);
Debug.Assert(args is LayerArgs);
return args as LayerArgs;
}

public static ModelConfig deserialize_model_config(JToken json)


Loading…
Cancel
Save