|
|
|
@@ -22,6 +22,7 @@ using System.Collections.Generic; |
|
|
|
using System.Data; |
|
|
|
using System.Diagnostics; |
|
|
|
using System.Linq; |
|
|
|
using System.Reflection; |
|
|
|
using Tensorflow.Keras.ArgsDefinition; |
|
|
|
using Tensorflow.Keras.Engine; |
|
|
|
using Tensorflow.Keras.Layers; |
|
|
|
@@ -58,59 +59,32 @@ namespace Tensorflow.Keras.Utils |
|
|
|
|
|
|
|
public static Layer deserialize_keras_object(string class_name, JToken config) |
|
|
|
{ |
|
|
|
return class_name switch |
|
|
|
{ |
|
|
|
"Sequential" => new Sequential(config.ToObject<SequentialArgs>()), |
|
|
|
"InputLayer" => new InputLayer(config.ToObject<InputLayerArgs>()), |
|
|
|
"Flatten" => new Flatten(config.ToObject<FlattenArgs>()), |
|
|
|
"ELU" => new ELU(config.ToObject<ELUArgs>()), |
|
|
|
"Dense" => new Dense(config.ToObject<DenseArgs>()), |
|
|
|
"Softmax" => new Softmax(config.ToObject<SoftmaxArgs>()), |
|
|
|
"Conv2D" => new Conv2D(config.ToObject<Conv2DArgs>()), |
|
|
|
"BatchNormalization" => new BatchNormalization(config.ToObject<BatchNormalizationArgs>()), |
|
|
|
"MaxPooling2D" => new MaxPooling2D(config.ToObject<MaxPooling2DArgs>()), |
|
|
|
"Dropout" => new Dropout(config.ToObject<DropoutArgs>()), |
|
|
|
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + |
|
|
|
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") |
|
|
|
}; |
|
|
|
var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); |
|
|
|
var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) |
|
|
|
.Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); |
|
|
|
var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); |
|
|
|
var args = deserializationGenericMethod.Invoke(config, null); |
|
|
|
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); |
|
|
|
Debug.Assert(layer is Layer); |
|
|
|
return layer as Layer; |
|
|
|
} |
|
|
|
|
|
|
|
public static Layer deserialize_keras_object(string class_name, LayerArgs args) |
|
|
|
{ |
|
|
|
return class_name switch |
|
|
|
{ |
|
|
|
"Sequential" => new Sequential(args as SequentialArgs), |
|
|
|
"InputLayer" => new InputLayer(args as InputLayerArgs), |
|
|
|
"Flatten" => new Flatten(args as FlattenArgs), |
|
|
|
"ELU" => new ELU(args as ELUArgs), |
|
|
|
"Dense" => new Dense(args as DenseArgs), |
|
|
|
"Softmax" => new Softmax(args as SoftmaxArgs), |
|
|
|
"Conv2D" => new Conv2D(args as Conv2DArgs), |
|
|
|
"BatchNormalization" => new BatchNormalization(args as BatchNormalizationArgs), |
|
|
|
"MaxPooling2D" => new MaxPooling2D(args as MaxPooling2DArgs), |
|
|
|
"Dropout" => new Dropout(args as DropoutArgs), |
|
|
|
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + |
|
|
|
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") |
|
|
|
}; |
|
|
|
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); |
|
|
|
Debug.Assert(layer is Layer); |
|
|
|
return layer as Layer; |
|
|
|
} |
|
|
|
|
|
|
|
public static LayerArgs? deserialize_layer_args(string class_name, JToken config) |
|
|
|
public static LayerArgs deserialize_layer_args(string class_name, JToken config) |
|
|
|
{ |
|
|
|
return class_name switch |
|
|
|
{ |
|
|
|
"Sequential" => config.ToObject<SequentialArgs>(), |
|
|
|
"InputLayer" => config.ToObject<InputLayerArgs>(), |
|
|
|
"Flatten" => config.ToObject<FlattenArgs>(), |
|
|
|
"ELU" => config.ToObject<ELUArgs>(), |
|
|
|
"Dense" => config.ToObject<DenseArgs>(), |
|
|
|
"Softmax" => config.ToObject<SoftmaxArgs>(), |
|
|
|
"Conv2D" => config.ToObject<Conv2DArgs>(), |
|
|
|
"BatchNormalization" => config.ToObject<BatchNormalizationArgs>(), |
|
|
|
"MaxPooling2D" => config.ToObject<MaxPooling2DArgs>(), |
|
|
|
"Dropout" => config.ToObject<DropoutArgs>(), |
|
|
|
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + |
|
|
|
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") |
|
|
|
}; |
|
|
|
var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); |
|
|
|
var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) |
|
|
|
.Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); |
|
|
|
var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); |
|
|
|
var args = deserializationGenericMethod.Invoke(config, null); |
|
|
|
Debug.Assert(args is LayerArgs); |
|
|
|
return args as LayerArgs; |
|
|
|
} |
|
|
|
|
|
|
|
public static ModelConfig deserialize_model_config(JToken json) |
|
|
|
|