| @@ -18,7 +18,7 @@ namespace Tensorflow | |||||
| return meta_graph_def; | return meta_graph_def; | ||||
| } | } | ||||
| public static (RefVariable[], string[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, | |||||
| public static (Dictionary<string, RefVariable>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, | |||||
| bool clear_devices = false, | bool clear_devices = false, | ||||
| string import_scope = "", | string import_scope = "", | ||||
| Dictionary<string, Tensor> input_map = null, | Dictionary<string, Tensor> input_map = null, | ||||
| @@ -89,7 +89,7 @@ namespace Tensorflow | |||||
| variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names); | variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names); | ||||
| variable_objects[value] = variable; | variable_objects[value] = variable; | ||||
| } | } | ||||
| variable = variable_objects[value]; | |||||
| graph.add_to_collection(col.Key, variable); | graph.add_to_collection(col.Key, variable); | ||||
| } | } | ||||
| } | } | ||||
| @@ -102,7 +102,12 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| return (null, null); | |||||
| var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, | |||||
| scope: scope_to_prepend_to_names) as List<RefVariable>; | |||||
| var var_list = new Dictionary<string, RefVariable>(); | |||||
| variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | |||||
| return (var_list, imported_return_elements); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -26,6 +26,12 @@ namespace Tensorflow | |||||
| public string _graph_key; | public string _graph_key; | ||||
| public Status Status { get; } | public Status Status { get; } | ||||
| /// <summary> | |||||
| /// True if the graph is considered "finalized". In that case no | |||||
| /// new operations can be added. | |||||
| /// </summary> | |||||
| private bool _finalized = false; | |||||
| /// <summary> | /// <summary> | ||||
| /// Arbitrary collections of objects. | /// Arbitrary collections of objects. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -126,6 +132,7 @@ namespace Tensorflow | |||||
| public void add_to_collection<T>(string name, T value) | public void add_to_collection<T>(string name, T value) | ||||
| { | { | ||||
| _check_not_finalized(); | |||||
| if (_collections.ContainsKey(name)) | if (_collections.ContainsKey(name)) | ||||
| (_collections[name] as List<T>).Add(value); | (_collections[name] as List<T>).Add(value); | ||||
| else | else | ||||
| @@ -138,6 +145,12 @@ namespace Tensorflow | |||||
| add_to_collection(name, value); | add_to_collection(name, value); | ||||
| } | } | ||||
| private void _check_not_finalized() | |||||
| { | |||||
| if (_finalized) | |||||
| throw new RuntimeError("Graph is finalized and cannot be modified."); | |||||
| } | |||||
| public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, | public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, | ||||
| TF_DataType[] input_types = null, string name = "", | TF_DataType[] input_types = null, string name = "", | ||||
| Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | ||||
| @@ -85,8 +85,8 @@ namespace Tensorflow | |||||
| // Add a placeholder string tensor for the filename. | // Add a placeholder string tensor for the filename. | ||||
| var filename_tensor = array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new int[0], name: "filename"); | var filename_tensor = array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new int[0], name: "filename"); | ||||
| filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const"); | |||||
| // Keep the name "Const" for backwards compatibility. | // Keep the name "Const" for backwards compatibility. | ||||
| filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const"); | |||||
| // Add the save ops. | // Add the save ops. | ||||
| if (sharded) | if (sharded) | ||||
| @@ -106,10 +106,19 @@ namespace Tensorflow | |||||
| var check_collection_list = graph.get_all_collection_keys(); | var check_collection_list = graph.get_all_collection_keys(); | ||||
| foreach (var collection_type in check_collection_list) | foreach (var collection_type in check_collection_list) | ||||
| { | { | ||||
| foreach (var element in graph.get_collection(collection_type) as IList<RefVariable>) | |||||
| var cols = graph.get_collection(collection_type); | |||||
| switch (cols) | |||||
| { | { | ||||
| case List<RefVariable> values: | |||||
| foreach (var element in values) ; | |||||
| break; | |||||
| case List<ITensorOrOperation> values: | |||||
| foreach (var element in values) ; | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("_build_internal.check_collection_list"); | |||||
| } | } | ||||
| } | } | ||||
| return new SaverDef() | return new SaverDef() | ||||
| @@ -193,7 +193,7 @@ namespace Tensorflow | |||||
| return _is_empty ? string.Empty : model_checkpoint_path; | return _is_empty ? string.Empty : model_checkpoint_path; | ||||
| } | } | ||||
| public Saver import_meta_graph(string meta_graph_or_file, | |||||
| public (Saver, object) import_meta_graph(string meta_graph_or_file, | |||||
| bool clear_devices = false, | bool clear_devices = false, | ||||
| string import_scope = "") | string import_scope = "") | ||||
| { | { | ||||
| @@ -6,22 +6,48 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class saver | public class saver | ||||
| { | { | ||||
| public static Saver _import_meta_graph_with_return_elements(string meta_graph_or_file, | |||||
| public static (Saver, object) _import_meta_graph_with_return_elements(string meta_graph_or_file, | |||||
| bool clear_devices = false, | bool clear_devices = false, | ||||
| string import_scope = "", | string import_scope = "", | ||||
| string[] return_elements = null) | string[] return_elements = null) | ||||
| { | { | ||||
| var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file); | var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file); | ||||
| meta_graph.import_scoped_meta_graph_with_return_elements( | |||||
| var imported_vars = meta_graph.import_scoped_meta_graph_with_return_elements( | |||||
| meta_graph_def, | meta_graph_def, | ||||
| clear_devices: clear_devices, | clear_devices: clear_devices, | ||||
| import_scope: import_scope, | import_scope: import_scope, | ||||
| return_elements: return_elements); | return_elements: return_elements); | ||||
| return null; | |||||
| /*var (imported_vars, imported_return_elements) = ( | |||||
| , false);*/ | |||||
| var saver = _create_saver_from_imported_meta_graph( | |||||
| meta_graph_def, import_scope, imported_vars); | |||||
| return (saver, null); | |||||
| } | |||||
| public static Saver _create_saver_from_imported_meta_graph(MetaGraphDef meta_graph_def, | |||||
| string import_scope, | |||||
| (Dictionary<string, RefVariable>, ITensorOrOperation[]) imported_vars) | |||||
| { | |||||
| if(meta_graph_def.SaverDef != null) | |||||
| { | |||||
| throw new NotImplementedException("_create_saver_from_imported_meta_graph"); | |||||
| } | |||||
| else | |||||
| { | |||||
| if(variables._all_saveable_objects(scope: import_scope).Length > 0) | |||||
| { | |||||
| // Return the default saver instance for all graph variables. | |||||
| return new Saver(); | |||||
| } | |||||
| else | |||||
| { | |||||
| // If no graph variables exist, then a Saver cannot be constructed. | |||||
| Console.WriteLine("Saver not created because there are no variables in the" + | |||||
| " graph to restore"); | |||||
| return null; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -19,7 +19,7 @@ namespace Tensorflow | |||||
| bool clear_devices = false, | bool clear_devices = false, | ||||
| string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file, | string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file, | ||||
| clear_devices, | clear_devices, | ||||
| import_scope); | |||||
| import_scope).Item1; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -28,7 +28,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public static string GLOBAL_VARIABLES = "variables"; | public static string GLOBAL_VARIABLES = "variables"; | ||||
| public static string[] _VARIABLE_COLLECTIONS = new string[] { "trainable_variables" }; | |||||
| public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables" }; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -440,5 +440,17 @@ namespace Tensorflow | |||||
| throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {value.GetType().Name} to Tensor"); | throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {value.GetType().Name} to Tensor"); | ||||
| } | } | ||||
| } | } | ||||
| public static string strip_name_scope(string name, string export_scope = "") | |||||
| { | |||||
| if (!string.IsNullOrEmpty(export_scope)) | |||||
| { | |||||
| throw new NotImplementedException("ops.strip_name_scope"); | |||||
| } | |||||
| else | |||||
| { | |||||
| return name; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||