From 4eccba4cf5a280d88d5385b24ee708980bdb4f19 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 13 Feb 2019 23:05:04 -0600 Subject: [PATCH] add LoadFromSavedModel --- src/TensorFlowNET.Core/Graphs/Graph.cs | 10 ++++++++++ src/TensorFlowNET.Core/Graphs/c_api.graph.cs | 19 +++++++++++++++++++ .../Operations/Losses/losses_impl.py.cs | 14 +++++++++++++- .../Operations/confusion_matrix.py.cs | 17 +++++++++++++++++ src/TensorFlowNET.Core/Sessions/Session.cs | 19 +++++++++++++++++++ .../TensorFlowNET.Core.csproj | 6 +++--- test/TensorFlowNET.UnitTest/TrainSaverTest.cs | 9 +++++++++ 7 files changed, 90 insertions(+), 4 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index da80f021..5455dae6 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -47,6 +47,16 @@ namespace Tensorflow _graph_key = $"grap-key-{ops.uid()}/"; } + public Graph(IntPtr handle) + { + _handle = handle; + Status = new Status(); + _nodes_by_id = new Dictionary(); + _nodes_by_name = new Dictionary(); + _names_in_use = new Dictionary(); + _graph_key = $"grap-key-{ops.uid()}/"; + } + public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) { return _as_graph_element_locked(obj, allow_tensor, allow_operation); diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 5e15d507..7ea1e2f1 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -254,6 +254,25 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); + /// + /// This function creates a new TF_Session (which is created on success) using + /// `session_options`, and then initializes state (restoring tensors and other + /// assets) using `run_options`. + /// + /// const TF_SessionOptions* + /// const TF_Buffer* + /// const char* + /// const char* const* + /// int + /// TF_Graph* + /// TF_Buffer* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_LoadSessionFromSavedModel(IntPtr session_options, IntPtr run_options, + string export_dir, string[] tags, int tags_len, + IntPtr graph, ref TF_Buffer meta_graph_def, IntPtr status); + [DllImport(TensorFlowLibName)] public static extern IntPtr TF_NewGraph(); diff --git a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs index 654d9f10..e45074b0 100644 --- a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs @@ -17,11 +17,23 @@ namespace Tensorflow (logits, labels, weights)), namescope => { - + (labels, logits, weights) = _remove_squeezable_dimensions( + labels, logits, weights, expected_rank_diff: 1); }); throw new NotImplementedException("sparse_softmax_cross_entropy"); } + + public (Tensor, Tensor, float) _remove_squeezable_dimensions(Tensor labels, + Tensor predictions, + float weights = 0, + int expected_rank_diff = 0) + { + (labels, predictions, weights) = confusion_matrix.remove_squeezable_dimensions( + labels, predictions, expected_rank_diff: expected_rank_diff); + + throw new NotImplementedException("_remove_squeezable_dimensions"); + } } } diff --git a/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs b/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs new file mode 100644 index 00000000..de26f7e1 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class confusion_matrix + { + public static (Tensor, Tensor, float) remove_squeezable_dimensions(Tensor labels, + Tensor predictions, + int expected_rank_diff = 0, + string name = "") + { + throw new NotImplementedException("remove_squeezable_dimensions"); + } + } +} diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index 6aa0707a..a4baac54 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -36,6 +36,25 @@ namespace Tensorflow Status.Check(true); } + public static Session LoadFromSavedModel(string path) + { + var graph = c_api.TF_NewGraph(); + var status = new Status(); + var opt = c_api.TF_NewSessionOptions(); + + var buffer = new TF_Buffer(); + var sess = c_api.TF_LoadSessionFromSavedModel(opt, IntPtr.Zero, path, new string[0], 0, graph, ref buffer, status); + + //var bytes = new Buffer(buffer.data).Data; + //var meta_graph = MetaGraphDef.Parser.ParseFrom(bytes); + + status.Check(); + + tf.g = new Graph(graph); + + return sess; + } + public static implicit operator IntPtr(Session session) => session._handle; public static implicit operator Session(IntPtr handle) => new Session(handle); diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 71af7b94..427b6314 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -4,7 +4,7 @@ netstandard2.0 TensorFlow.NET Tensorflow - 0.2.0 + 0.3.0 Haiping Chen SciSharp STACK true @@ -16,13 +16,13 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET Google's TensorFlow binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.2.0.0 + 0.3.0.0 Added a bunch of APIs. Fixed String tensor creation bug. Upgraded to TensorFlow 1.13 RC-1. 7.2 - 0.2.0.0 + 0.3.0.0 diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs index 34a5808b..e0b4a9db 100644 --- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs +++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs @@ -26,6 +26,15 @@ namespace TensorFlowNET.UnitTest }); } + [TestMethod] + public void ImportSavedModel() + { + with(Session.LoadFromSavedModel("mobilenet"), sess => + { + + }); + } + [TestMethod] public void Save1() {