| @@ -0,0 +1,40 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow.Device | |||
| { | |||
| public sealed class SafeDeviceListHandle : SafeTensorflowHandle | |||
| { | |||
| private SafeDeviceListHandle() | |||
| { | |||
| } | |||
| public SafeDeviceListHandle(IntPtr handle) | |||
| : base(handle) | |||
| { | |||
| } | |||
| protected override bool ReleaseHandle() | |||
| { | |||
| c_api.TF_DeleteDeviceList(handle); | |||
| SetHandle(IntPtr.Zero); | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| @@ -16,6 +16,7 @@ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| using Tensorflow.Device; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow | |||
| @@ -36,7 +37,7 @@ namespace Tensorflow | |||
| /// <param name="list">TF_DeviceList*</param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_DeviceListCount(IntPtr list); | |||
| public static extern int TF_DeviceListCount(SafeDeviceListHandle list); | |||
| /// <summary> | |||
| /// Retrieves the type of the device at the given index. | |||
| @@ -46,7 +47,7 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status); | |||
| public static extern IntPtr TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Deallocates the device list. | |||
| @@ -77,6 +78,6 @@ namespace Tensorflow | |||
| /// <param name="index"></param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status); | |||
| public static extern IntPtr TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status); | |||
| } | |||
| } | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| using Tensorflow.Device; | |||
| using Tensorflow.Eager; | |||
| using TFE_Executor = System.IntPtr; | |||
| @@ -317,7 +318,7 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); | |||
| public static extern SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// | |||
| @@ -1,6 +1,7 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using Tensorflow; | |||
| using Tensorflow.Device; | |||
| using Tensorflow.Eager; | |||
| using Buffer = System.Buffer; | |||
| @@ -8,14 +9,14 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| public class CApiTest | |||
| { | |||
| protected TF_Code TF_OK = TF_Code.TF_OK; | |||
| protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | |||
| protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL; | |||
| protected static readonly TF_Code TF_OK = TF_Code.TF_OK; | |||
| protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | |||
| protected static readonly TF_DataType TF_BOOL = TF_DataType.TF_BOOL; | |||
| protected void EXPECT_TRUE(bool expected, string msg = "") | |||
| => Assert.IsTrue(expected, msg); | |||
| protected void EXPECT_EQ(object expected, object actual, string msg = "") | |||
| protected static void EXPECT_EQ(object expected, object actual, string msg = "") | |||
| => Assert.AreEqual(expected, actual, msg); | |||
| protected void CHECK_EQ(object expected, object actual, string msg = "") | |||
| @@ -63,10 +64,10 @@ namespace TensorFlowNET.UnitTest | |||
| protected TF_Code TF_GetCode(Status s) | |||
| => s.Code; | |||
| protected TF_Code TF_GetCode(SafeStatusHandle s) | |||
| protected static TF_Code TF_GetCode(SafeStatusHandle s) | |||
| => c_api.TF_GetCode(s); | |||
| protected string TF_Message(SafeStatusHandle s) | |||
| protected static string TF_Message(SafeStatusHandle s) | |||
| => c_api.StringPiece(c_api.TF_Message(s)); | |||
| protected SafeStatusHandle TF_NewStatus() | |||
| @@ -141,21 +142,18 @@ namespace TensorFlowNET.UnitTest | |||
| protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status) | |||
| => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | |||
| protected IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | |||
| protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | |||
| => c_api.TFE_ContextListDevices(ctx, status); | |||
| protected int TF_DeviceListCount(IntPtr list) | |||
| protected int TF_DeviceListCount(SafeDeviceListHandle list) | |||
| => c_api.TF_DeviceListCount(list); | |||
| protected string TF_DeviceListType(IntPtr list, int index, SafeStatusHandle status) | |||
| protected string TF_DeviceListType(SafeDeviceListHandle list, int index, SafeStatusHandle status) | |||
| => c_api.StringPiece(c_api.TF_DeviceListType(list, index, status)); | |||
| protected string TF_DeviceListName(IntPtr list, int index, SafeStatusHandle status) | |||
| protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) | |||
| => c_api.StringPiece(c_api.TF_DeviceListName(list, index, status)); | |||
| protected void TF_DeleteDeviceList(IntPtr list) | |||
| => c_api.TF_DeleteDeviceList(list); | |||
| protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | |||
| => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | |||
| @@ -1,6 +1,6 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using Tensorflow; | |||
| using Tensorflow.Device; | |||
| using Tensorflow.Eager; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| @@ -21,13 +21,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| return c_api.TFE_NewContext(opts, status); | |||
| } | |||
| IntPtr devices; | |||
| using (var ctx = NewContext(status)) | |||
| static SafeDeviceListHandle ListDevices(SafeStatusHandle status) | |||
| { | |||
| devices = c_api.TFE_ContextListDevices(ctx, status); | |||
| using var ctx = NewContext(status); | |||
| var devices = c_api.TFE_ContextListDevices(ctx, status); | |||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| return devices; | |||
| } | |||
| using var devices = ListDevices(status); | |||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| int num_devices = c_api.TF_DeviceListCount(devices); | |||
| @@ -37,8 +39,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| EXPECT_NE("", c_api.TF_DeviceListName(devices, i, status), TF_Message(status)); | |||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| } | |||
| c_api.TF_DeleteDeviceList(devices); | |||
| } | |||
| } | |||
| } | |||
| @@ -43,8 +43,8 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type) | |||
| { | |||
| var status = TF_NewStatus(); | |||
| var devices = TFE_ContextListDevices(ctx, status); | |||
| using var status = TF_NewStatus(); | |||
| using var devices = TFE_ContextListDevices(ctx, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| int num_devices = TF_DeviceListCount(devices); | |||
| @@ -57,12 +57,10 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| if (dev_type == device_type) | |||
| { | |||
| device_name = dev_name; | |||
| TF_DeleteDeviceList(devices); | |||
| return true; | |||
| } | |||
| } | |||
| TF_DeleteDeviceList(devices); | |||
| return false; | |||
| } | |||