| @@ -289,7 +289,7 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, | public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, | ||||
| string export_dir, string[] tags, int tags_len, | string export_dir, string[] tags, int tags_len, | ||||
| IntPtr graph, ref TF_Buffer meta_graph_def, SafeStatusHandle status); | |||||
| IntPtr graph, IntPtr meta_graph_def, SafeStatusHandle status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_NewGraph(); | public static extern IntPtr TF_NewGraph(); | ||||
| @@ -36,6 +36,12 @@ namespace Tensorflow | |||||
| protected byte[] _target; | protected byte[] _target; | ||||
| public Graph graph => _graph; | public Graph graph => _graph; | ||||
| public BaseSession(IntPtr handle, Graph g) | |||||
| { | |||||
| _handle = handle; | |||||
| _graph = g ?? ops.get_default_graph(); | |||||
| } | |||||
| public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | ||||
| { | { | ||||
| _graph = g ?? ops.get_default_graph(); | _graph = g ?? ops.get_default_graph(); | ||||
| @@ -291,12 +297,8 @@ namespace Tensorflow | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
| { | { | ||||
| lock (Locks.ProcessWide) | |||||
| using (var status = new Status()) | |||||
| { | |||||
| c_api.TF_DeleteSession(handle, status.Handle); | |||||
| status.Check(true); | |||||
| } | |||||
| // c_api.TF_CloseSession(handle, tf.Status.Handle); | |||||
| c_api.TF_DeleteSession(handle, tf.Status.Handle); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -26,10 +26,8 @@ namespace Tensorflow | |||||
| public Session(string target = "", Graph g = null) : base(target, g, null) | public Session(string target = "", Graph g = null) : base(target, g, null) | ||||
| { } | { } | ||||
| public Session(IntPtr handle, Graph g = null) : base("", g, null) | |||||
| { | |||||
| _handle = handle; | |||||
| } | |||||
| public Session(IntPtr handle, Graph g = null) : base(handle, g) | |||||
| { } | |||||
| public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) | public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) | ||||
| { } | { } | ||||
| @@ -39,51 +37,29 @@ namespace Tensorflow | |||||
| return ops.set_default_session(this); | return ops.set_default_session(this); | ||||
| } | } | ||||
| [MethodImpl(MethodImplOptions.NoOptimization)] | |||||
| public static Session LoadFromSavedModel(string path) | public static Session LoadFromSavedModel(string path) | ||||
| { | { | ||||
| lock (Locks.ProcessWide) | |||||
| { | |||||
| var graph = c_api.TF_NewGraph(); | |||||
| using var status = new Status(); | |||||
| var opt = new SessionOptions(); | |||||
| var tags = new string[] { "serve" }; | |||||
| var buffer = new TF_Buffer(); | |||||
| IntPtr sess; | |||||
| try | |||||
| { | |||||
| sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, | |||||
| IntPtr.Zero, | |||||
| path, | |||||
| tags, | |||||
| tags.Length, | |||||
| graph, | |||||
| ref buffer, | |||||
| status.Handle); | |||||
| status.Check(true); | |||||
| } | |||||
| catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) | |||||
| { | |||||
| sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, | |||||
| IntPtr.Zero, | |||||
| Path.GetFullPath(path), | |||||
| tags, | |||||
| tags.Length, | |||||
| graph, | |||||
| ref buffer, | |||||
| status.Handle); | |||||
| status.Check(true); | |||||
| } | |||||
| // load graph bytes | |||||
| // var data = new byte[buffer.length]; | |||||
| // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||||
| // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||||
| return new Session(sess, g: new Graph(graph)).as_default(); | |||||
| } | |||||
| using var graph = new Graph(); | |||||
| using var status = new Status(); | |||||
| using var opt = c_api.TF_NewSessionOptions(); | |||||
| var tags = new string[] { "serve" }; | |||||
| var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||||
| IntPtr.Zero, | |||||
| path, | |||||
| tags, | |||||
| tags.Length, | |||||
| graph, | |||||
| IntPtr.Zero, | |||||
| status.Handle); | |||||
| status.Check(true); | |||||
| // load graph bytes | |||||
| // var data = new byte[buffer.length]; | |||||
| // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||||
| // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||||
| return new Session(sess, g: graph); | |||||
| } | } | ||||
| public static implicit operator IntPtr(Session session) => session._handle; | public static implicit operator IntPtr(Session session) => session._handle; | ||||
| @@ -21,6 +21,18 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class c_api | public partial class c_api | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Close a session. | |||||
| /// | |||||
| /// Contacts any other processes associated with the session, if applicable. | |||||
| /// May not be called after TF_DeleteSession(). | |||||
| /// </summary> | |||||
| /// <param name="s"></param> | |||||
| /// <param name="status"></param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_CloseSession(IntPtr session, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Destroy a session object. | /// Destroy a session object. | ||||
| /// | /// | ||||
| @@ -6,6 +6,7 @@ using System.Linq; | |||||
| using System.Reflection; | using System.Reflection; | ||||
| using System.Text; | using System.Text; | ||||
| using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Benchmark.Leak | namespace Tensorflow.Benchmark.Leak | ||||
| { | { | ||||
| @@ -18,13 +19,9 @@ namespace Tensorflow.Benchmark.Leak | |||||
| var modelDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); | var modelDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); | ||||
| var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); | var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); | ||||
| for (var i = 0; i < 50; i++) | |||||
| { | |||||
| var session = Session.LoadFromSavedModel(ClassifierModelPath); | |||||
| session.graph.Exit(); | |||||
| session.graph.Dispose(); | |||||
| session.Dispose(); | |||||
| for (var i = 0; i < 1024; i++) | |||||
| { | |||||
| using var sess = Session.LoadFromSavedModel(ClassifierModelPath); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -13,7 +13,9 @@ namespace TensorFlowBenchmark | |||||
| static void Main(string[] args) | static void Main(string[] args) | ||||
| { | { | ||||
| print(tf.VERSION); | print(tf.VERSION); | ||||
| /*new RepeatDataSetCrash().Run(); | |||||
| /*new SavedModelCleanup().Run(); | |||||
| new RepeatDataSetCrash().Run(); | |||||
| new GpuLeakByCNN().Run();*/ | new GpuLeakByCNN().Run();*/ | ||||
| if (args?.Length > 0) | if (args?.Length > 0) | ||||
| @@ -37,7 +37,7 @@ | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="BenchmarkDotNet" Version="0.13.0" /> | <PackageReference Include="BenchmarkDotNet" Version="0.13.0" /> | ||||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0-rc0" /> | |||||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||