diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs
index 8d9957ac..b9bc430d 100644
--- a/src/TensorFlowNET.Core/APIs/tf.train.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.train.cs
@@ -56,8 +56,8 @@ namespace Tensorflow
public Graph load_graph(string freeze_graph_pb)
=> saver.load_graph(freeze_graph_pb);
- public string freeze_graph(string checkpoint_dir, string output_pb_name)
- => saver.freeze_graph(checkpoint_dir, output_pb_name);
+ public string freeze_graph(string checkpoint_dir, string output_pb_name, string[] output_node_names)
+ => saver.freeze_graph(checkpoint_dir, output_pb_name, output_node_names);
public Saver import_meta_graph(string meta_graph_or_file,
bool clear_devices = false,
diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj
index cf3540c2..dd83ae65 100644
--- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj
+++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj
@@ -5,7 +5,7 @@
TensorFlow.NET
Tensorflow
1.14.1
- 0.14.1
+ 0.14.1.1
Haiping Chen, Meinrad Recheis, Eli Belash
SciSharp STACK
true
@@ -18,12 +18,12 @@
Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io
- 0.14.1.0
+ 0.14.1.1
Changes since v0.14.0:
1: Add TransformGraphWithStringInputs.
2: tf.trainer.load_graph, tf.trainer.freeze_graph
7.3
- 0.14.1.0
+ 0.14.1.1
LICENSE
true
true
diff --git a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs
index 5fa79d1c..5f119791 100644
--- a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs
+++ b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs
@@ -85,7 +85,9 @@ namespace Tensorflow
}
}
- public static string freeze_graph(string checkpoint_dir, string output_pb_name)
+ public static string freeze_graph(string checkpoint_dir,
+ string output_pb_name,
+ string[] output_node_names)
{
var checkpoint = checkpoint_management.latest_checkpoint(checkpoint_dir);
if (!File.Exists($"{checkpoint}.meta")) return null;
@@ -99,7 +101,7 @@ namespace Tensorflow
saver.restore(sess, checkpoint);
var output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
graph.as_graph_def(),
- new string[] { "output/ArgMax" });
+ output_node_names);
Console.WriteLine($"Froze {output_graph_def.Node.Count} nodes.");
File.WriteAllBytes(output_pb, output_graph_def.ToByteArray());
return output_pb;