Fix the error when loading VGG19

@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.ArgsDefinition
     /// This class has nothing but the attributes different from `LayerArgs`.
     /// It's used to serialize the model to `tf` format.
     /// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`,
-    /// then the Arg definition should inherit `utoSerializeLayerArgs` instead of `LayerArgs`.
+    /// then the Arg definition should inherit `AutoSerializeLayerArgs` instead of `LayerArgs`.
     /// </summary>
     public class AutoSerializeLayerArgs: LayerArgs
     {
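
Side note, not part of the diff: the doc comment above is a contributor guideline. A minimal illustrative sketch of what it asks for, using a hypothetical `MyLayerArgs` type with made-up properties, might look like this:

```csharp
// Hypothetical example only: a layer whose Python get_config() calls super().get_config()
// gets an args class deriving from AutoSerializeLayerArgs instead of LayerArgs,
// so the shared attributes are serialized to the `tf` format automatically.
namespace Tensorflow.Keras.ArgsDefinition
{
    public class MyLayerArgs : AutoSerializeLayerArgs
    {
        // Only the attributes that differ from LayerArgs are declared here.
        public int Units { get; set; }
        public string Activation { get; set; } = "linear";
    }
}
```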

@@ -7,6 +7,11 @@ using System.Text;

 namespace Tensorflow.Keras.Common
 {
+    class ShapeInfoFromPython
+    {
+        public string class_name { get; set; }
+        public long?[] items { get; set; }
+    }
     public class CustomizedShapeJsonConverter: JsonConverter
     {
         public override bool CanConvert(Type objectType)

@@ -44,36 +49,23 @@ namespace Tensorflow.Keras.Common
                         dims[i] = shape.dims[i];
                     }
                 }
-                var token = JToken.FromObject(dims);
+                var token = JToken.FromObject(new ShapeInfoFromPython()
+                {
+                    class_name = "__tuple__",
+                    items = dims
+                });
                 token.WriteTo(writer);
             }
         }

         public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
         {
-            long?[] dims;
-            try
-            {
-                dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
-            }
-            catch (JsonSerializationException ex)
-            {
-                if (reader.Value.Equals("class_name"))
-                {
-                    reader.Read();
-                    reader.Read();
-                    reader.Read();
-                    dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
-                }
-                else
-                {
-                    throw ex;
-                }
-            }
-            if (dims is null)
+            var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader);
+            if (shape_info_from_python is null)
             {
                 return null;
             }
+            long?[] dims = shape_info_from_python.items;
             long[] convertedDims = new long[dims.Length];
             for(int i = 0; i < dims.Length; i++)
             {
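
For context on the converter change: after this hunk a shape is written and read in the wrapped form that Keras on the Python side uses for tuples, i.e. an object with `class_name` set to `"__tuple__"` and the dimensions in `items` (with `null` for an unknown dimension). A small standalone round-trip sketch using plain Newtonsoft.Json and a DTO that mirrors `ShapeInfoFromPython`; the JSON literal below is an assumed example, not taken from a real saved model:

```csharp
using System;
using Newtonsoft.Json;

// Mirrors the ShapeInfoFromPython DTO introduced in the diff.
class ShapeInfo
{
    public string class_name { get; set; }
    public long?[] items { get; set; }
}

class ShapeJsonDemo
{
    static void Main()
    {
        // Assumed example of a shape field in a Keras-written config, e.g. (None, 224, 224, 3).
        var json = "{\"class_name\": \"__tuple__\", \"items\": [null, 224, 224, 3]}";

        // Reading: deserialize straight into the DTO, as the rewritten ReadJson now does.
        var info = JsonConvert.DeserializeObject<ShapeInfo>(json);
        Console.WriteLine(string.Join(", ",
            Array.ConvertAll(info.items, d => d?.ToString() ?? "null"))); // null, 224, 224, 3

        // Writing: wrap the dimensions back into the same structure, as the rewritten WriteJson does.
        var roundTripped = JsonConvert.SerializeObject(new ShapeInfo
        {
            class_name = "__tuple__",
            items = info.items
        });
        Console.WriteLine(roundTripped);
    }
}
```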

@@ -108,7 +108,7 @@ https://tensorflownet.readthedocs.io</Description>
   <ItemGroup>
     <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
-    <PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
+    <PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
     <PackageReference Include="OneOf" Version="3.0.223" />
     <PackageReference Include="Protobuf.Text" Version="0.7.0" />
     <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />

@@ -563,7 +563,7 @@ namespace Tensorflow
             return proto.KindCase switch
             {
                 SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id),
-                SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null),
+                SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, dependencies),
                 SavedObject.KindOneofCase.BareConcreteFunction => _recreate_bare_concrete_function(proto.BareConcreteFunction, dependencies),
                 SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable),
                 SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(),

@@ -626,7 +626,7 @@ namespace Tensorflow
         }

         private (Function, Action<object, object, object>) _recreate_function(SavedFunction proto,
-            Dictionary<OneOf<string, int>, Trackable> dependencies)
+            IDictionary<OneOf<string, int>, Trackable> dependencies)
         {
             var fn = function_deserialization.recreate_function(proto, _concrete_functions);
             foreach (var name in proto.ConcreteFunctions)

@@ -644,6 +644,13 @@ namespace Tensorflow
             return (fn, setattr);
         }

+        private (Tensor, Action<object, object, object>) _get_tensor_from_fn(CapturedTensor proto)
+        {
+            var outer_graph = _concrete_functions[proto.ConcreteFunction].func_graph;
+            var captured_tensor = outer_graph.get_tensor_by_name(proto.Name);
+            return (captured_tensor, setattr);
+        }
+
         // TODO: remove this to a common class.
         public static Action<object, object, object> setattr = (x, y, z) =>
         {

@@ -71,6 +71,9 @@ namespace Tensorflow.Keras.Utils
             var args = deserializationGenericMethod.Invoke(config, null);
             var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null);
             Debug.Assert(layer is Layer);
+
+            // TODO(Rinne): _shared_object_loading_scope().set(shared_object_id, deserialized_obj)
+
             return layer as Layer;
         }

@@ -82,6 +85,9 @@ namespace Tensorflow.Keras.Utils
                 return null;
             }
             Debug.Assert(layer is Layer);
+
+            // TODO(Rinne): _shared_object_loading_scope().set(shared_object_id, deserialized_obj)
+
             return layer as Layer;
         }

@@ -6,13 +6,13 @@ using Tensorflow.Keras.Optimizers;
 using Tensorflow.Keras.UnitTest.Helpers;
 using Tensorflow.NumPy;
 using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;

 namespace TensorFlowNET.Keras.UnitTest.SaveModel;

 [TestClass]
 public class SequentialModelLoad
 {
-    [Ignore]
     [TestMethod]
     public void SimpleModelFromAutoCompile()
     {

@@ -80,4 +80,27 @@ public class SequentialModelLoad
         model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
     }

+    [Ignore]
+    [TestMethod]
+    public void VGG19()
+    {
+        var model = tf.keras.models.load_model(@"D:\development\tf.net\models\VGG19");
+        model.summary();
+
+        var classify_model = keras.Sequential(new System.Collections.Generic.List<Tensorflow.Keras.ILayer>()
+        {
+            model,
+            keras.layers.Flatten(),
+            keras.layers.Dense(10),
+        });
+        classify_model.summary();
+
+        classify_model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
+
+        var x = np.random.uniform(0, 1, (8, 512, 512, 3));
+        var y = np.ones((8));
+        classify_model.fit(x, y, batch_size: 4);
+    }
 }