| @@ -47,6 +47,16 @@ namespace Tensorflow | |||||
| _graph_key = $"grap-key-{ops.uid()}/"; | _graph_key = $"grap-key-{ops.uid()}/"; | ||||
| } | } | ||||
| public Graph(IntPtr handle) | |||||
| { | |||||
| _handle = handle; | |||||
| Status = new Status(); | |||||
| _nodes_by_id = new Dictionary<int, ITensorOrOperation>(); | |||||
| _nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | |||||
| _names_in_use = new Dictionary<string, int>(); | |||||
| _graph_key = $"grap-key-{ops.uid()}/"; | |||||
| } | |||||
| public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | ||||
| { | { | ||||
| return _as_graph_element_locked(obj, allow_tensor, allow_operation); | return _as_graph_element_locked(obj, allow_tensor, allow_operation); | ||||
| @@ -254,6 +254,25 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); | public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); | ||||
| /// <summary> | |||||
| /// This function creates a new TF_Session (which is created on success) using | |||||
| /// `session_options`, and then initializes state (restoring tensors and other | |||||
| /// assets) using `run_options`. | |||||
| /// </summary> | |||||
| /// <param name="session_options">const TF_SessionOptions*</param> | |||||
| /// <param name="run_options">const TF_Buffer*</param> | |||||
| /// <param name="export_dir">const char*</param> | |||||
| /// <param name="tags">const char* const*</param> | |||||
| /// <param name="tags_len">int</param> | |||||
| /// <param name="graph">TF_Graph*</param> | |||||
| /// <param name="meta_graph_def">TF_Buffer*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TF_LoadSessionFromSavedModel(IntPtr session_options, IntPtr run_options, | |||||
| string export_dir, string[] tags, int tags_len, | |||||
| IntPtr graph, ref TF_Buffer meta_graph_def, IntPtr status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_NewGraph(); | public static extern IntPtr TF_NewGraph(); | ||||
| @@ -17,11 +17,23 @@ namespace Tensorflow | |||||
| (logits, labels, weights)), | (logits, labels, weights)), | ||||
| namescope => | namescope => | ||||
| { | { | ||||
| (labels, logits, weights) = _remove_squeezable_dimensions( | |||||
| labels, logits, weights, expected_rank_diff: 1); | |||||
| }); | }); | ||||
| throw new NotImplementedException("sparse_softmax_cross_entropy"); | throw new NotImplementedException("sparse_softmax_cross_entropy"); | ||||
| } | } | ||||
| public (Tensor, Tensor, float) _remove_squeezable_dimensions(Tensor labels, | |||||
| Tensor predictions, | |||||
| float weights = 0, | |||||
| int expected_rank_diff = 0) | |||||
| { | |||||
| (labels, predictions, weights) = confusion_matrix.remove_squeezable_dimensions( | |||||
| labels, predictions, expected_rank_diff: expected_rank_diff); | |||||
| throw new NotImplementedException("_remove_squeezable_dimensions"); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,17 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class confusion_matrix | |||||
| { | |||||
| public static (Tensor, Tensor, float) remove_squeezable_dimensions(Tensor labels, | |||||
| Tensor predictions, | |||||
| int expected_rank_diff = 0, | |||||
| string name = "") | |||||
| { | |||||
| throw new NotImplementedException("remove_squeezable_dimensions"); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -36,6 +36,25 @@ namespace Tensorflow | |||||
| Status.Check(true); | Status.Check(true); | ||||
| } | } | ||||
| public static Session LoadFromSavedModel(string path) | |||||
| { | |||||
| var graph = c_api.TF_NewGraph(); | |||||
| var status = new Status(); | |||||
| var opt = c_api.TF_NewSessionOptions(); | |||||
| var buffer = new TF_Buffer(); | |||||
| var sess = c_api.TF_LoadSessionFromSavedModel(opt, IntPtr.Zero, path, new string[0], 0, graph, ref buffer, status); | |||||
| //var bytes = new Buffer(buffer.data).Data; | |||||
| //var meta_graph = MetaGraphDef.Parser.ParseFrom(bytes); | |||||
| status.Check(); | |||||
| tf.g = new Graph(graph); | |||||
| return sess; | |||||
| } | |||||
| public static implicit operator IntPtr(Session session) => session._handle; | public static implicit operator IntPtr(Session session) => session._handle; | ||||
| public static implicit operator Session(IntPtr handle) => new Session(handle); | public static implicit operator Session(IntPtr handle) => new Session(handle); | ||||
| @@ -4,7 +4,7 @@ | |||||
| <TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <Version>0.2.0</Version> | |||||
| <Version>0.3.0</Version> | |||||
| <Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
| <Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
| @@ -16,13 +16,13 @@ | |||||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> | <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> | ||||
| <Description>Google's TensorFlow binding in .NET Standard. | <Description>Google's TensorFlow binding in .NET Standard. | ||||
| Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
| <AssemblyVersion>0.2.0.0</AssemblyVersion> | |||||
| <AssemblyVersion>0.3.0.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Added a bunch of APIs. | <PackageReleaseNotes>Added a bunch of APIs. | ||||
| Fixed String tensor creation bug. | Fixed String tensor creation bug. | ||||
| Upgraded to TensorFlow 1.13 RC-1. | Upgraded to TensorFlow 1.13 RC-1. | ||||
| </PackageReleaseNotes> | </PackageReleaseNotes> | ||||
| <LangVersion>7.2</LangVersion> | <LangVersion>7.2</LangVersion> | ||||
| <FileVersion>0.2.0.0</FileVersion> | |||||
| <FileVersion>0.3.0.0</FileVersion> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
| @@ -26,6 +26,15 @@ namespace TensorFlowNET.UnitTest | |||||
| }); | }); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void ImportSavedModel() | |||||
| { | |||||
| with<Session>(Session.LoadFromSavedModel("mobilenet"), sess => | |||||
| { | |||||
| }); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Save1() | public void Save1() | ||||
| { | { | ||||