Browse Source

tf.trainer.load_graph, tf.trainer.freeze_graph

tags/v0.20
Oceania2018 5 years ago
parent
commit
170f3a774e
3 changed files with 44 additions and 4 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/APIs/tf.train.cs
  2. +5
    -4
      src/TensorFlowNET.Core/TensorFlow.Binding.csproj
  3. +33
    -0
      src/TensorFlowNET.Core/Training/Saving/saver.py.cs

+ 6
- 0
src/TensorFlowNET.Core/APIs/tf.train.cs View File

@@ -53,6 +53,12 @@ namespace Tensorflow
public string write_graph(Graph graph, string logdir, string name, bool as_text = true)
=> graph_io.write_graph(graph, logdir, name, as_text);

public Graph load_graph(string freeze_graph_pb)
=> saver.load_graph(freeze_graph_pb);

public string freeze_graph(string checkpoint_dir, string output_pb_name)
=> saver.freeze_graph(checkpoint_dir, output_pb_name);

public Saver import_meta_graph(string meta_graph_or_file,
bool clear_devices = false,
string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file,


+ 5
- 4
src/TensorFlowNET.Core/TensorFlow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.1</TargetTensorFlow>
<Version>0.15.0</Version>
<Version>0.14.1</Version>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -18,11 +18,12 @@
<Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.15.0.0</AssemblyVersion>
<AssemblyVersion>0.14.1.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.14.0:
1: Add TransformGraphWithStringInputs.</PackageReleaseNotes>
1: Add TransformGraphWithStringInputs.
2: tf.trainer.load_graph, tf.trainer.freeze_graph</PackageReleaseNotes>
<LangVersion>7.3</LangVersion>
<FileVersion>0.15.0.0</FileVersion>
<FileVersion>0.14.1.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>


+ 33
- 0
src/TensorFlowNET.Core/Training/Saving/saver.py.cs View File

@@ -14,9 +14,12 @@
limitations under the License.
******************************************************************************/

using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -81,5 +84,35 @@ namespace Tensorflow
}
}
}

public static string freeze_graph(string checkpoint_dir, string output_pb_name)
{
var checkpoint = checkpoint_management.latest_checkpoint(checkpoint_dir);
if (!File.Exists($"{checkpoint}.meta")) return null;

string output_pb = Path.GetFullPath(Path.Combine(checkpoint_dir, "../", $"{output_pb_name}.pb"));

using (var graph = tf.Graph())
using (var sess = tf.Session(graph))
{
var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true);
saver.restore(sess, checkpoint);
var output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
graph.as_graph_def(),
new string[] { "output/ArgMax" });
Console.WriteLine($"Froze {output_graph_def.Node.Count} nodes.");
File.WriteAllBytes(output_pb, output_graph_def.ToByteArray());
return output_pb;
}
}

public static Graph load_graph(string freeze_graph_pb, string name = "")
{
var bytes = File.ReadAllBytes(freeze_graph_pb);
var graph = tf.Graph().as_default();
importer.import_graph_def(GraphDef.Parser.ParseFrom(bytes),
name: name);
return graph;
}
}
}

Loading…
Cancel
Save