diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs
index 862212ef..8d9957ac 100644
--- a/src/TensorFlowNET.Core/APIs/tf.train.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.train.cs
@@ -53,6 +53,12 @@ namespace Tensorflow
public string write_graph(Graph graph, string logdir, string name, bool as_text = true)
=> graph_io.write_graph(graph, logdir, name, as_text);
+ public Graph load_graph(string freeze_graph_pb)
+ => saver.load_graph(freeze_graph_pb);
+
+ public string freeze_graph(string checkpoint_dir, string output_pb_name)
+ => saver.freeze_graph(checkpoint_dir, output_pb_name);
+
public Saver import_meta_graph(string meta_graph_or_file,
bool clear_devices = false,
string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file,
diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj
index 031d2b10..cf3540c2 100644
--- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj
+++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj
@@ -5,7 +5,7 @@
TensorFlow.NET
Tensorflow
1.14.1
- 0.15.0
+ 0.14.1
Haiping Chen, Meinrad Recheis, Eli Belash
SciSharp STACK
true
@@ -18,11 +18,12 @@
Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io
- 0.15.0.0
+ 0.14.1.0
Changes since v0.14.0:
-1: Add TransformGraphWithStringInputs.
+1: Add TransformGraphWithStringInputs.
+2: tf.trainer.load_graph, tf.trainer.freeze_graph
7.3
- 0.15.0.0
+ 0.14.1.0
LICENSE
true
true
diff --git a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs
index 2b75947b..5fa79d1c 100644
--- a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs
+++ b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs
@@ -14,9 +14,12 @@
limitations under the License.
******************************************************************************/
+using Google.Protobuf;
using System;
using System.Collections.Generic;
+using System.IO;
using System.Linq;
+using static Tensorflow.Binding;
namespace Tensorflow
{
@@ -81,5 +84,35 @@ namespace Tensorflow
}
}
}
+
+ public static string freeze_graph(string checkpoint_dir, string output_pb_name)
+ {
+ var checkpoint = checkpoint_management.latest_checkpoint(checkpoint_dir);
+ if (!File.Exists($"{checkpoint}.meta")) return null;
+
+ string output_pb = Path.GetFullPath(Path.Combine(checkpoint_dir, "../", $"{output_pb_name}.pb"));
+
+ using (var graph = tf.Graph())
+ using (var sess = tf.Session(graph))
+ {
+ var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true);
+ saver.restore(sess, checkpoint);
+ var output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
+ graph.as_graph_def(),
+ new string[] { "output/ArgMax" });
+ Console.WriteLine($"Froze {output_graph_def.Node.Count} nodes.");
+ File.WriteAllBytes(output_pb, output_graph_def.ToByteArray());
+ return output_pb;
+ }
+ }
+
+ public static Graph load_graph(string freeze_graph_pb, string name = "")
+ {
+ var bytes = File.ReadAllBytes(freeze_graph_pb);
+ var graph = tf.Graph().as_default();
+ importer.import_graph_def(GraphDef.Parser.ParseFrom(bytes),
+ name: name);
+ return graph;
+ }
}
}