| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Functions; | |||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Training; | using Tensorflow.Training; | ||||
| using pbc = global::Google.Protobuf.Collections; | using pbc = global::Google.Protobuf.Collections; | ||||
| @@ -13,7 +14,7 @@ public static class CheckPointUtils | |||||
| { | { | ||||
| private static string _ESCAPE_CHAR = "."; | private static string _ESCAPE_CHAR = "."; | ||||
| public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>, | public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>, | ||||
| IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>, | |||||
| IDictionary<Trackable, pbc::RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>, | |||||
| IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) | IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) | ||||
| { | { | ||||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | ||||
| @@ -93,13 +93,14 @@ public class SaveableView | |||||
| // | // | ||||
| // } | // } | ||||
| foreach (var obj in _nodes) | |||||
| { | |||||
| if (obj is ConcreteFunction) | |||||
| { | |||||
| _concrete_functions.Add((ConcreteFunction)obj); | |||||
| } | |||||
| } | |||||
| //_concrete_functions = new(); | |||||
| //foreach (var obj in _nodes) | |||||
| //{ | |||||
| // if (obj is ConcreteFunction) | |||||
| // { | |||||
| // _concrete_functions.Add((ConcreteFunction)obj); | |||||
| // } | |||||
| //} | |||||
| } | } | ||||
| public List<ConcreteFunction> get_concrete_resource_initializers() | public List<ConcreteFunction> get_concrete_resource_initializers() | ||||
| @@ -225,8 +226,8 @@ public class SaveableView | |||||
| } | } | ||||
| else if (obj is ConcreteFunction) | else if (obj is ConcreteFunction) | ||||
| { | { | ||||
| // TODO: complete it. | |||||
| throw new NotImplementedException(); | |||||
| // TODO(Rinne): complete it. | |||||
| // throw new NotImplementedException(); | |||||
| } | } | ||||
| // skip the process of type `_CapturedTensor` and `CapturableResource`. | // skip the process of type `_CapturedTensor` and `CapturableResource`. | ||||
| else | else | ||||
| @@ -17,7 +17,14 @@ namespace Tensorflow | |||||
| { | { | ||||
| protected string _name; | protected string _name; | ||||
| public virtual string Name => _handle_name; | public virtual string Name => _handle_name; | ||||
| public virtual string SharedName => _name; | |||||
| public virtual string SharedName | |||||
| { | |||||
| get | |||||
| { | |||||
| // TODO(Rinne): optimize the implementation with refactor of variable. | |||||
| return _handle_name.Substring(0, _handle_name.IndexOf(':') + 1); | |||||
| } | |||||
| } | |||||
| protected TF_DataType _dtype; | protected TF_DataType _dtype; | ||||
| public TF_DataType dtype => _dtype; | public TF_DataType dtype => _dtype; | ||||
| protected string _handle_name; | protected string _handle_name; | ||||
| @@ -152,6 +152,39 @@ namespace Tensorflow.Keras.Saving | |||||
| _reconstruct_all_models(); | _reconstruct_all_models(); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Removes tracked references that are only used when loading the model. | |||||
| /// Now that the node object has been fully loaded, and the checkpoint has | |||||
| /// been restored, the object no longer needs to track objects added from | |||||
| /// SerializedAttributes. (Note that saving a training checkpoint still | |||||
| /// functions correctly, because layers and variables are tracked | |||||
| /// separately by the Layer object.) | |||||
| /// </summary> | |||||
| public void del_tracking() | |||||
| { | |||||
| foreach(var (node, _) in loaded_nodes.Values) | |||||
| { | |||||
| if(node is not Layer layer) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| foreach(var name in PUBLIC_ATTRIBUTES.Keys) | |||||
| { | |||||
| layer._delete_tracking(name); | |||||
| } | |||||
| if(node is Functional functional) | |||||
| { | |||||
| foreach(var name in functional.UnconditionalDependencyNames.Keys) | |||||
| { | |||||
| if(Regex.Match(name, @"^layer(_with_weights)?-[\d+]").Success) | |||||
| { | |||||
| functional._delete_tracking(name); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| private void _reconstruct_all_models() | private void _reconstruct_all_models() | ||||
| { | { | ||||
| HashSet<int> all_initialized_models = new(); | HashSet<int> all_initialized_models = new(); | ||||
| @@ -77,7 +77,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| var loaded = Loader.load_partial(path, nodes_to_load, options); | var loaded = Loader.load_partial(path, nodes_to_load, options); | ||||
| keras_loader.finalize_objects(); | keras_loader.finalize_objects(); | ||||
| // keras_loader.del_tracking(); | |||||
| keras_loader.del_tracking(); | |||||
| var model = loaded["root"]; | var model = loaded["root"]; | ||||
| @@ -196,5 +196,17 @@ namespace Tensorflow.Keras.UnitTest.Model | |||||
| // ) | // ) | ||||
| #endregion | #endregion | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void SaveAfterLoad() | |||||
| { | |||||
| var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile"); | |||||
| model.summary(); | |||||
| model.save("Assets/saved_auto_compile_after_loading"); | |||||
| //model = tf.keras.models.load_model(@"Assets/saved_auto_compile_after_loading"); | |||||
| //model.summary(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||