| @@ -0,0 +1,40 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow.Device | |||||
| { | |||||
| public sealed class SafeDeviceListHandle : SafeTensorflowHandle | |||||
| { | |||||
| private SafeDeviceListHandle() | |||||
| { | |||||
| } | |||||
| public SafeDeviceListHandle(IntPtr handle) | |||||
| : base(handle) | |||||
| { | |||||
| } | |||||
| protected override bool ReleaseHandle() | |||||
| { | |||||
| c_api.TF_DeleteDeviceList(handle); | |||||
| SetHandle(IntPtr.Zero); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -16,6 +16,7 @@ | |||||
| using System; | using System; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Tensorflow.Device; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -36,7 +37,7 @@ namespace Tensorflow | |||||
| /// <param name="list">TF_DeviceList*</param> | /// <param name="list">TF_DeviceList*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TF_DeviceListCount(IntPtr list); | |||||
| public static extern int TF_DeviceListCount(SafeDeviceListHandle list); | |||||
| /// <summary> | /// <summary> | ||||
| /// Retrieves the type of the device at the given index. | /// Retrieves the type of the device at the given index. | ||||
| @@ -46,7 +47,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status); | |||||
| public static extern IntPtr TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Deallocates the device list. | /// Deallocates the device list. | ||||
| @@ -77,6 +78,6 @@ namespace Tensorflow | |||||
| /// <param name="index"></param> | /// <param name="index"></param> | ||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status); | |||||
| public static extern IntPtr TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Tensorflow.Device; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using TFE_Executor = System.IntPtr; | using TFE_Executor = System.IntPtr; | ||||
| @@ -317,7 +318,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); | |||||
| public static extern SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -1,6 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Device; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Buffer = System.Buffer; | using Buffer = System.Buffer; | ||||
| @@ -8,14 +9,14 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| public class CApiTest | public class CApiTest | ||||
| { | { | ||||
| protected TF_Code TF_OK = TF_Code.TF_OK; | |||||
| protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | |||||
| protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL; | |||||
| protected static readonly TF_Code TF_OK = TF_Code.TF_OK; | |||||
| protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | |||||
| protected static readonly TF_DataType TF_BOOL = TF_DataType.TF_BOOL; | |||||
| protected void EXPECT_TRUE(bool expected, string msg = "") | protected void EXPECT_TRUE(bool expected, string msg = "") | ||||
| => Assert.IsTrue(expected, msg); | => Assert.IsTrue(expected, msg); | ||||
| protected void EXPECT_EQ(object expected, object actual, string msg = "") | |||||
| protected static void EXPECT_EQ(object expected, object actual, string msg = "") | |||||
| => Assert.AreEqual(expected, actual, msg); | => Assert.AreEqual(expected, actual, msg); | ||||
| protected void CHECK_EQ(object expected, object actual, string msg = "") | protected void CHECK_EQ(object expected, object actual, string msg = "") | ||||
| @@ -63,10 +64,10 @@ namespace TensorFlowNET.UnitTest | |||||
| protected TF_Code TF_GetCode(Status s) | protected TF_Code TF_GetCode(Status s) | ||||
| => s.Code; | => s.Code; | ||||
| protected TF_Code TF_GetCode(SafeStatusHandle s) | |||||
| protected static TF_Code TF_GetCode(SafeStatusHandle s) | |||||
| => c_api.TF_GetCode(s); | => c_api.TF_GetCode(s); | ||||
| protected string TF_Message(SafeStatusHandle s) | |||||
| protected static string TF_Message(SafeStatusHandle s) | |||||
| => c_api.StringPiece(c_api.TF_Message(s)); | => c_api.StringPiece(c_api.TF_Message(s)); | ||||
| protected SafeStatusHandle TF_NewStatus() | protected SafeStatusHandle TF_NewStatus() | ||||
| @@ -141,21 +142,18 @@ namespace TensorFlowNET.UnitTest | |||||
| protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status) | protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status) | ||||
| => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | ||||
| protected IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | |||||
| protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | |||||
| => c_api.TFE_ContextListDevices(ctx, status); | => c_api.TFE_ContextListDevices(ctx, status); | ||||
| protected int TF_DeviceListCount(IntPtr list) | |||||
| protected int TF_DeviceListCount(SafeDeviceListHandle list) | |||||
| => c_api.TF_DeviceListCount(list); | => c_api.TF_DeviceListCount(list); | ||||
| protected string TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status) | |||||
| protected string TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status) | |||||
| => c_api.StringPiece(c_api.TF_DeviceListType(list, index, status)); | => c_api.StringPiece(c_api.TF_DeviceListType(list, index, status)); | ||||
| protected string TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status) | |||||
| protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) | |||||
| => c_api.StringPiece(c_api.TF_DeviceListName(list, index, status)); | => c_api.StringPiece(c_api.TF_DeviceListName(list, index, status)); | ||||
| protected void TF_DeleteDeviceList(IntPtr list) | |||||
| => c_api.TF_DeleteDeviceList(list); | |||||
| protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | ||||
| => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | ||||
| @@ -1,6 +1,6 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Device; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
| @@ -21,13 +21,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| return c_api.TFE_NewContext(opts, status); | return c_api.TFE_NewContext(opts, status); | ||||
| } | } | ||||
| IntPtr devices; | |||||
| using (var ctx = NewContext(status)) | |||||
| static SafeDeviceListHandle ListDevices(SafeStatusHandle status) | |||||
| { | { | ||||
| devices = c_api.TFE_ContextListDevices(ctx, status); | |||||
| using var ctx = NewContext(status); | |||||
| var devices = c_api.TFE_ContextListDevices(ctx, status); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| return devices; | |||||
| } | } | ||||
| using var devices = ListDevices(status); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| int num_devices = c_api.TF_DeviceListCount(devices); | int num_devices = c_api.TF_DeviceListCount(devices); | ||||
| @@ -37,8 +39,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| EXPECT_NE("", c_api.TF_DeviceListName(devices, i, status), TF_Message(status)); | EXPECT_NE("", c_api.TF_DeviceListName(devices, i, status), TF_Message(status)); | ||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| } | } | ||||
| c_api.TF_DeleteDeviceList(devices); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -43,8 +43,8 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type) | bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type) | ||||
| { | { | ||||
| var status = TF_NewStatus(); | |||||
| var devices = TFE_ContextListDevices(ctx, status); | |||||
| using var status = TF_NewStatus(); | |||||
| using var devices = TFE_ContextListDevices(ctx, status); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| int num_devices = TF_DeviceListCount(devices); | int num_devices = TF_DeviceListCount(devices); | ||||
| @@ -57,12 +57,10 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| if (dev_type == device_type) | if (dev_type == device_type) | ||||
| { | { | ||||
| device_name = dev_name; | device_name = dev_name; | ||||
| TF_DeleteDeviceList(devices); | |||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| TF_DeleteDeviceList(devices); | |||||
| return false; | return false; | ||||
| } | } | ||||