Browse Source

tf.train.import_meta_graph passed.

tags/v0.8.0
haiping008 6 years ago
parent
commit
de082e1029
8 changed files with 79 additions and 14 deletions
  1. +8
    -3
      src/TensorFlowNET.Core/Framework/meta_graph.py.cs
  2. +13
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  3. +12
    -3
      src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  5. +31
    -5
      src/TensorFlowNET.Core/Train/Saving/saver.py.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Train/tf.optimizers.cs
  7. +1
    -1
      src/TensorFlowNET.Core/ops.GraphKeys.cs
  8. +12
    -0
      src/TensorFlowNET.Core/ops.py.cs

+ 8
- 3
src/TensorFlowNET.Core/Framework/meta_graph.py.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow
return meta_graph_def;
}

public static (RefVariable[], string[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
public static (Dictionary<string, RefVariable>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
bool clear_devices = false,
string import_scope = "",
Dictionary<string, Tensor> input_map = null,
@@ -89,7 +89,7 @@ namespace Tensorflow
variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
variable_objects[value] = variable;
}
variable = variable_objects[value];
graph.add_to_collection(col.Key, variable);
}
}
@@ -102,7 +102,12 @@ namespace Tensorflow
}
}

return (null, null);
var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
scope: scope_to_prepend_to_names) as List<RefVariable>;
var var_list = new Dictionary<string, RefVariable>();
variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v);

return (var_list, imported_return_elements);
}

/// <summary>


+ 13
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -26,6 +26,12 @@ namespace Tensorflow
public string _graph_key;
public Status Status { get; }

/// <summary>
/// True if the graph is considered "finalized". In that case no
/// new operations can be added.
/// </summary>
private bool _finalized = false;

/// <summary>
/// Arbitrary collections of objects.
/// </summary>
@@ -126,6 +132,7 @@ namespace Tensorflow

public void add_to_collection<T>(string name, T value)
{
_check_not_finalized();
if (_collections.ContainsKey(name))
(_collections[name] as List<T>).Add(value);
else
@@ -138,6 +145,12 @@ namespace Tensorflow
add_to_collection(name, value);
}

private void _check_not_finalized()
{
if (_finalized)
throw new RuntimeError("Graph is finalized and cannot be modified.");
}

public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
TF_DataType[] input_types = null, string name = "",
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)


+ 12
- 3
src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs View File

@@ -85,8 +85,8 @@ namespace Tensorflow

// Add a placeholder string tensor for the filename.
var filename_tensor = array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new int[0], name: "filename");
filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const");
// Keep the name "Const" for backwards compatibility.
filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const");

// Add the save ops.
if (sharded)
@@ -106,10 +106,19 @@ namespace Tensorflow
var check_collection_list = graph.get_all_collection_keys();
foreach (var collection_type in check_collection_list)
{
foreach (var element in graph.get_collection(collection_type) as IList<RefVariable>)
var cols = graph.get_collection(collection_type);
switch (cols)
{

case List<RefVariable> values:
foreach (var element in values) ;
break;
case List<ITensorOrOperation> values:
foreach (var element in values) ;
break;
default:
throw new NotImplementedException("_build_internal.check_collection_list");
}
}

return new SaverDef()


+ 1
- 1
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -193,7 +193,7 @@ namespace Tensorflow
return _is_empty ? string.Empty : model_checkpoint_path;
}

public Saver import_meta_graph(string meta_graph_or_file,
public (Saver, object) import_meta_graph(string meta_graph_or_file,
bool clear_devices = false,
string import_scope = "")
{


+ 31
- 5
src/TensorFlowNET.Core/Train/Saving/saver.py.cs View File

@@ -6,22 +6,48 @@ namespace Tensorflow
{
public class saver
{
public static Saver _import_meta_graph_with_return_elements(string meta_graph_or_file,
public static (Saver, object) _import_meta_graph_with_return_elements(string meta_graph_or_file,
bool clear_devices = false,
string import_scope = "",
string[] return_elements = null)
{
var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file);

meta_graph.import_scoped_meta_graph_with_return_elements(
var imported_vars = meta_graph.import_scoped_meta_graph_with_return_elements(
meta_graph_def,
clear_devices: clear_devices,
import_scope: import_scope,
return_elements: return_elements);

return null;
/*var (imported_vars, imported_return_elements) = (
, false);*/
var saver = _create_saver_from_imported_meta_graph(
meta_graph_def, import_scope, imported_vars);

return (saver, null);
}

public static Saver _create_saver_from_imported_meta_graph(MetaGraphDef meta_graph_def,
string import_scope,
(Dictionary<string, RefVariable>, ITensorOrOperation[]) imported_vars)
{
if(meta_graph_def.SaverDef != null)
{
throw new NotImplementedException("_create_saver_from_imported_meta_graph");
}
else
{
if(variables._all_saveable_objects(scope: import_scope).Length > 0)
{
// Return the default saver instance for all graph variables.
return new Saver();
}
else
{
// If no graph variables exist, then a Saver cannot be constructed.
Console.WriteLine("Saver not created because there are no variables in the" +
" graph to restore");
return null;
}
}
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Train/tf.optimizers.cs View File

@@ -19,7 +19,7 @@ namespace Tensorflow
bool clear_devices = false,
string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file,
clear_devices,
import_scope);
import_scope).Item1;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/ops.GraphKeys.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow
/// </summary>
public static string GLOBAL_VARIABLES = "variables";

public static string[] _VARIABLE_COLLECTIONS = new string[] { "trainable_variables" };
public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables" };
/// <summary>
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
/// </summary>


+ 12
- 0
src/TensorFlowNET.Core/ops.py.cs View File

@@ -440,5 +440,17 @@ namespace Tensorflow
throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {value.GetType().Name} to Tensor");
}
}

public static string strip_name_scope(string name, string export_scope = "")
{
if (!string.IsNullOrEmpty(export_scope))
{
throw new NotImplementedException("ops.strip_name_scope");
}
else
{
return name;
}
}
}
}

Loading…
Cancel
Save