| @@ -18,7 +18,7 @@ namespace Tensorflow | |||
| return meta_graph_def; | |||
| } | |||
| public static (RefVariable[], string[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, | |||
| public static (Dictionary<string, RefVariable>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, | |||
| bool clear_devices = false, | |||
| string import_scope = "", | |||
| Dictionary<string, Tensor> input_map = null, | |||
| @@ -89,7 +89,7 @@ namespace Tensorflow | |||
| variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names); | |||
| variable_objects[value] = variable; | |||
| } | |||
| variable = variable_objects[value]; | |||
| graph.add_to_collection(col.Key, variable); | |||
| } | |||
| } | |||
| @@ -102,7 +102,12 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| return (null, null); | |||
| var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, | |||
| scope: scope_to_prepend_to_names) as List<RefVariable>; | |||
| var var_list = new Dictionary<string, RefVariable>(); | |||
| variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | |||
| return (var_list, imported_return_elements); | |||
| } | |||
| /// <summary> | |||
| @@ -26,6 +26,12 @@ namespace Tensorflow | |||
| public string _graph_key; | |||
| public Status Status { get; } | |||
| /// <summary> | |||
| /// True if the graph is considered "finalized". In that case no | |||
| /// new operations can be added. | |||
| /// </summary> | |||
| private bool _finalized = false; | |||
| /// <summary> | |||
| /// Arbitrary collections of objects. | |||
| /// </summary> | |||
| @@ -126,6 +132,7 @@ namespace Tensorflow | |||
| public void add_to_collection<T>(string name, T value) | |||
| { | |||
| _check_not_finalized(); | |||
| if (_collections.ContainsKey(name)) | |||
| (_collections[name] as List<T>).Add(value); | |||
| else | |||
| @@ -138,6 +145,12 @@ namespace Tensorflow | |||
| add_to_collection(name, value); | |||
| } | |||
| private void _check_not_finalized() | |||
| { | |||
| if (_finalized) | |||
| throw new RuntimeError("Graph is finalized and cannot be modified."); | |||
| } | |||
| public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, | |||
| TF_DataType[] input_types = null, string name = "", | |||
| Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | |||
| @@ -85,8 +85,8 @@ namespace Tensorflow | |||
| // Add a placeholder string tensor for the filename. | |||
| var filename_tensor = array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new int[0], name: "filename"); | |||
| filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const"); | |||
| // Keep the name "Const" for backwards compatibility. | |||
| filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const"); | |||
| // Add the save ops. | |||
| if (sharded) | |||
| @@ -106,10 +106,19 @@ namespace Tensorflow | |||
| var check_collection_list = graph.get_all_collection_keys(); | |||
| foreach (var collection_type in check_collection_list) | |||
| { | |||
| foreach (var element in graph.get_collection(collection_type) as IList<RefVariable>) | |||
| var cols = graph.get_collection(collection_type); | |||
| switch (cols) | |||
| { | |||
| case List<RefVariable> values: | |||
| foreach (var element in values) ; | |||
| break; | |||
| case List<ITensorOrOperation> values: | |||
| foreach (var element in values) ; | |||
| break; | |||
| default: | |||
| throw new NotImplementedException("_build_internal.check_collection_list"); | |||
| } | |||
| } | |||
| return new SaverDef() | |||
| @@ -193,7 +193,7 @@ namespace Tensorflow | |||
| return _is_empty ? string.Empty : model_checkpoint_path; | |||
| } | |||
| public Saver import_meta_graph(string meta_graph_or_file, | |||
| public (Saver, object) import_meta_graph(string meta_graph_or_file, | |||
| bool clear_devices = false, | |||
| string import_scope = "") | |||
| { | |||
| @@ -6,22 +6,48 @@ namespace Tensorflow | |||
| { | |||
| public class saver | |||
| { | |||
| public static Saver _import_meta_graph_with_return_elements(string meta_graph_or_file, | |||
| public static (Saver, object) _import_meta_graph_with_return_elements(string meta_graph_or_file, | |||
| bool clear_devices = false, | |||
| string import_scope = "", | |||
| string[] return_elements = null) | |||
| { | |||
| var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file); | |||
| meta_graph.import_scoped_meta_graph_with_return_elements( | |||
| var imported_vars = meta_graph.import_scoped_meta_graph_with_return_elements( | |||
| meta_graph_def, | |||
| clear_devices: clear_devices, | |||
| import_scope: import_scope, | |||
| return_elements: return_elements); | |||
| return null; | |||
| /*var (imported_vars, imported_return_elements) = ( | |||
| , false);*/ | |||
| var saver = _create_saver_from_imported_meta_graph( | |||
| meta_graph_def, import_scope, imported_vars); | |||
| return (saver, null); | |||
| } | |||
| public static Saver _create_saver_from_imported_meta_graph(MetaGraphDef meta_graph_def, | |||
| string import_scope, | |||
| (Dictionary<string, RefVariable>, ITensorOrOperation[]) imported_vars) | |||
| { | |||
| if(meta_graph_def.SaverDef != null) | |||
| { | |||
| throw new NotImplementedException("_create_saver_from_imported_meta_graph"); | |||
| } | |||
| else | |||
| { | |||
| if(variables._all_saveable_objects(scope: import_scope).Length > 0) | |||
| { | |||
| // Return the default saver instance for all graph variables. | |||
| return new Saver(); | |||
| } | |||
| else | |||
| { | |||
| // If no graph variables exist, then a Saver cannot be constructed. | |||
| Console.WriteLine("Saver not created because there are no variables in the" + | |||
| " graph to restore"); | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -19,7 +19,7 @@ namespace Tensorflow | |||
| bool clear_devices = false, | |||
| string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file, | |||
| clear_devices, | |||
| import_scope); | |||
| import_scope).Item1; | |||
| } | |||
| } | |||
| } | |||
| @@ -28,7 +28,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public static string GLOBAL_VARIABLES = "variables"; | |||
| public static string[] _VARIABLE_COLLECTIONS = new string[] { "trainable_variables" }; | |||
| public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables" }; | |||
| /// <summary> | |||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||
| /// </summary> | |||
| @@ -440,5 +440,17 @@ namespace Tensorflow | |||
| throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {value.GetType().Name} to Tensor"); | |||
| } | |||
| } | |||
| public static string strip_name_scope(string name, string export_scope = "") | |||
| { | |||
| if (!string.IsNullOrEmpty(export_scope)) | |||
| { | |||
| throw new NotImplementedException("ops.strip_name_scope"); | |||
| } | |||
| else | |||
| { | |||
| return name; | |||
| } | |||
| } | |||
| } | |||
| } | |||