From de082e1029311ea7678059279946b25f422e8b41 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Wed, 13 Feb 2019 10:32:24 -0600 Subject: [PATCH] tf.train.import_meta_graph passed. --- .../Framework/meta_graph.py.cs | 11 ++++-- src/TensorFlowNET.Core/Graphs/Graph.cs | 13 +++++++ .../Train/Saving/BaseSaverBuilder.cs | 15 ++++++-- src/TensorFlowNET.Core/Train/Saving/Saver.cs | 2 +- .../Train/Saving/saver.py.cs | 36 ++++++++++++++++--- src/TensorFlowNET.Core/Train/tf.optimizers.cs | 2 +- src/TensorFlowNET.Core/ops.GraphKeys.cs | 2 +- src/TensorFlowNET.Core/ops.py.cs | 12 +++++++ 8 files changed, 79 insertions(+), 14 deletions(-) diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs index 580bc66d..6972ce1d 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs @@ -18,7 +18,7 @@ namespace Tensorflow return meta_graph_def; } - public static (RefVariable[], string[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, + public static (Dictionary, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, bool clear_devices = false, string import_scope = "", Dictionary input_map = null, @@ -89,7 +89,7 @@ namespace Tensorflow variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names); variable_objects[value] = variable; } - + variable = variable_objects[value]; graph.add_to_collection(col.Key, variable); } } @@ -102,7 +102,12 @@ namespace Tensorflow } } - return (null, null); + var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, + scope: scope_to_prepend_to_names) as List; + var var_list = new Dictionary(); + variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); + + return (var_list, imported_return_elements); } /// diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index b281debf..da80f021 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -26,6 +26,12 @@ namespace Tensorflow public string _graph_key; public Status Status { get; } + /// + /// True if the graph is considered "finalized". In that case no + /// new operations can be added. + /// + private bool _finalized = false; + /// /// Arbitrary collections of objects. /// @@ -126,6 +132,7 @@ namespace Tensorflow public void add_to_collection(string name, T value) { + _check_not_finalized(); if (_collections.ContainsKey(name)) (_collections[name] as List).Add(value); else @@ -138,6 +145,12 @@ namespace Tensorflow add_to_collection(name, value); } + private void _check_not_finalized() + { + if (_finalized) + throw new RuntimeError("Graph is finalized and cannot be modified."); + } + public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = "", Dictionary attrs = null, OpDef op_def = null) diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs index b4cb952e..5b1f07e2 100644 --- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs +++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs @@ -85,8 +85,8 @@ namespace Tensorflow // Add a placeholder string tensor for the filename. var filename_tensor = array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new int[0], name: "filename"); - filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const"); // Keep the name "Const" for backwards compatibility. + filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const"); // Add the save ops. if (sharded) @@ -106,10 +106,19 @@ namespace Tensorflow var check_collection_list = graph.get_all_collection_keys(); foreach (var collection_type in check_collection_list) { - foreach (var element in graph.get_collection(collection_type) as IList) + var cols = graph.get_collection(collection_type); + switch (cols) { - + case List values: + foreach (var element in values) ; + break; + case List values: + foreach (var element in values) ; + break; + default: + throw new NotImplementedException("_build_internal.check_collection_list"); } + } return new SaverDef() diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs index 816ffea7..5d436e00 100644 --- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs @@ -193,7 +193,7 @@ namespace Tensorflow return _is_empty ? string.Empty : model_checkpoint_path; } - public Saver import_meta_graph(string meta_graph_or_file, + public (Saver, object) import_meta_graph(string meta_graph_or_file, bool clear_devices = false, string import_scope = "") { diff --git a/src/TensorFlowNET.Core/Train/Saving/saver.py.cs b/src/TensorFlowNET.Core/Train/Saving/saver.py.cs index 344e2078..bb22702f 100644 --- a/src/TensorFlowNET.Core/Train/Saving/saver.py.cs +++ b/src/TensorFlowNET.Core/Train/Saving/saver.py.cs @@ -6,22 +6,48 @@ namespace Tensorflow { public class saver { - public static Saver _import_meta_graph_with_return_elements(string meta_graph_or_file, + public static (Saver, object) _import_meta_graph_with_return_elements(string meta_graph_or_file, bool clear_devices = false, string import_scope = "", string[] return_elements = null) { var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file); - meta_graph.import_scoped_meta_graph_with_return_elements( + var imported_vars = meta_graph.import_scoped_meta_graph_with_return_elements( meta_graph_def, clear_devices: clear_devices, import_scope: import_scope, return_elements: return_elements); - return null; - /*var (imported_vars, imported_return_elements) = ( - , false);*/ + var saver = _create_saver_from_imported_meta_graph( + meta_graph_def, import_scope, imported_vars); + + return (saver, null); + } + + public static Saver _create_saver_from_imported_meta_graph(MetaGraphDef meta_graph_def, + string import_scope, + (Dictionary, ITensorOrOperation[]) imported_vars) + { + if(meta_graph_def.SaverDef != null) + { + throw new NotImplementedException("_create_saver_from_imported_meta_graph"); + } + else + { + if(variables._all_saveable_objects(scope: import_scope).Length > 0) + { + // Return the default saver instance for all graph variables. + return new Saver(); + } + else + { + // If no graph variables exist, then a Saver cannot be constructed. + Console.WriteLine("Saver not created because there are no variables in the" + + " graph to restore"); + return null; + } + } } } } diff --git a/src/TensorFlowNET.Core/Train/tf.optimizers.cs b/src/TensorFlowNET.Core/Train/tf.optimizers.cs index 8579047a..e82d3dc0 100644 --- a/src/TensorFlowNET.Core/Train/tf.optimizers.cs +++ b/src/TensorFlowNET.Core/Train/tf.optimizers.cs @@ -19,7 +19,7 @@ namespace Tensorflow bool clear_devices = false, string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, - import_scope); + import_scope).Item1; } } } diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index 4f918b2a..ef107ff3 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -28,7 +28,7 @@ namespace Tensorflow /// public static string GLOBAL_VARIABLES = "variables"; - public static string[] _VARIABLE_COLLECTIONS = new string[] { "trainable_variables" }; + public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables" }; /// /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. /// diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index b7ccfe86..026e7feb 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -440,5 +440,17 @@ namespace Tensorflow throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {value.GetType().Name} to Tensor"); } } + + public static string strip_name_scope(string name, string export_scope = "") + { + if (!string.IsNullOrEmpty(export_scope)) + { + throw new NotImplementedException("ops.strip_name_scope"); + } + else + { + return name; + } + } } }