| @@ -13,8 +13,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras", "src\Ten | |||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}" | ||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorFlowNET.Hub\Tensorflow.Hub.csproj", "{95B077C1-E21B-486F-8BDD-1C902FE687AB}" | |||||
| EndProject | |||||
| Global | Global | ||||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
| Debug|Any CPU = Debug|Any CPU | Debug|Any CPU = Debug|Any CPU | ||||
| @@ -107,22 +105,6 @@ Global | |||||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|Any CPU | ||||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|Any CPU | ||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|Any CPU.Build.0 = Debug|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|x64.ActiveCfg = Debug|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|x64.Build.0 = Debug|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|x64.ActiveCfg = Release|Any CPU | |||||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|x64.Build.0 = Release|Any CPU | |||||
| EndGlobalSection | EndGlobalSection | ||||
| GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
| HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
| @@ -261,7 +261,7 @@ namespace Tensorflow | |||||
| public Tensor divide(Tensor a, Tensor b) | public Tensor divide(Tensor a, Tensor b) | ||||
| => gen_math_ops.real_div(a, b); | |||||
| => a / b; | |||||
| public Tensor sqrt(Tensor a, string name = null) | public Tensor sqrt(Tensor a, string name = null) | ||||
| => gen_math_ops.sqrt(a, name); | => gen_math_ops.sqrt(a, name); | ||||
| @@ -77,6 +77,19 @@ namespace Tensorflow | |||||
| Console.WriteLine(_tostring(obj)); | Console.WriteLine(_tostring(obj)); | ||||
| } | } | ||||
| public static void print(string format, params object[] objects) | |||||
| { | |||||
| if (!format.Contains("{}")) | |||||
| Console.WriteLine(format, string.Join(" ", objects.Select(x => x.ToString()))); | |||||
| foreach(var obj in objects) | |||||
| { | |||||
| } | |||||
| Console.WriteLine(format); | |||||
| } | |||||
| public static int len(object a) | public static int len(object a) | ||||
| { | { | ||||
| switch (a) | switch (a) | ||||
| @@ -37,6 +37,16 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TF_DeviceListCount(IntPtr list); | public static extern int TF_DeviceListCount(IntPtr list); | ||||
| /// <summary> | |||||
| /// Retrieves the type of the device at the given index. | |||||
| /// </summary> | |||||
| /// <param name="list">TF_DeviceList*</param> | |||||
| /// <param name="index">int</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TF_DeviceListType(IntPtr list, int index, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Deallocates the device list. | /// Deallocates the device list. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -44,6 +54,18 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_DeleteDeviceList(IntPtr list); | public static extern void TF_DeleteDeviceList(IntPtr list); | ||||
| /// <summary> | |||||
| /// Create a new TFE_TensorHandle with the same contents as 'h' but placed | |||||
| /// in the memory of the device name 'device_name'. | |||||
| /// </summary> | |||||
| /// <param name="h">TFE_TensorHandle*</param> | |||||
| /// <param name="ctx">TFE_Context*</param> | |||||
| /// <param name="device_name">char*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| /// <returns>TFE_TensorHandle*</returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) | /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) | ||||
| /// The return value will be a pointer to a null terminated string. The caller | /// The return value will be a pointer to a null terminated string. The caller | ||||
| @@ -54,6 +76,6 @@ namespace Tensorflow | |||||
| /// <param name="index"></param> | /// <param name="index"></param> | ||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern string TF_DeviceListName(IntPtr list, int index, IntPtr status); | |||||
| public static extern IntPtr TF_DeviceListName(IntPtr list, int index, IntPtr status); | |||||
| } | } | ||||
| } | } | ||||
| @@ -159,6 +159,28 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TFE_TensorHandleNumDims(IntPtr h, IntPtr status); | public static extern int TFE_TensorHandleNumDims(IntPtr h, IntPtr status); | ||||
| /// <summary> | |||||
| /// Returns the device of the operation that produced `h`. If `h` was produced by | |||||
| /// a copy, returns the destination device of the copy. Note that the returned | |||||
| /// device name is not always the device holding the tensor handle's memory. If | |||||
| /// you want the latter, use TFE_TensorHandleBackingDeviceName. This function | |||||
| /// will block till the operation that produces `h` has completed. | |||||
| /// </summary> | |||||
| /// <param name="h">TFE_TensorHandle*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TFE_TensorHandleDeviceName(IntPtr h, IntPtr status); | |||||
| /// <summary> | |||||
| /// Returns the name of the device in whose memory `h` resides. | |||||
| /// </summary> | |||||
| /// <param name="h">TFE_TensorHandle*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TFE_TensorHandleBackingDeviceName(IntPtr h, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| @@ -14,7 +14,7 @@ namespace Tensorflow.Eager | |||||
| string device_name, | string device_name, | ||||
| string opName, | string opName, | ||||
| string name, | string name, | ||||
| params Tensor[] inputs) | |||||
| params object[] inputs) | |||||
| { | { | ||||
| IntPtr op = IntPtr.Zero; | IntPtr op = IntPtr.Zero; | ||||
| var attr_list_sizes = new Dictionary<string, int>(); | var attr_list_sizes = new Dictionary<string, int>(); | ||||
| @@ -42,7 +42,14 @@ namespace Tensorflow.Eager | |||||
| else | else | ||||
| { | { | ||||
| // The item is a single item. | // The item is a single item. | ||||
| AddInputToOp(inputs[i], true, input_arg, op, status); | |||||
| switch (inputs[i]) | |||||
| { | |||||
| case Tensor inputTensor: | |||||
| AddInputToOp(inputTensor, true, input_arg, op, status); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -140,19 +140,14 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static EagerTensor add(Tensor x, Tensor y, string name = null) | |||||
| public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null) | |||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| { | { | ||||
| var _result = pywrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Add", name, new[] { x, y }); | |||||
| var _result = pywrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Add", name, x, y); | |||||
| return _result; | return _result; | ||||
| } | } | ||||
| return null; | |||||
| } | |||||
| public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | ||||
| return _op.output; | return _op.output; | ||||
| @@ -469,6 +464,12 @@ namespace Tensorflow | |||||
| public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= null) | public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = pywrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Cast", name, x, "DstT", DstT, "Truncate", Truncate); | |||||
| return _result; | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); | var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| @@ -490,6 +491,12 @@ namespace Tensorflow | |||||
| public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = pywrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Sub", name, x, y); | |||||
| return _result; | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| @@ -535,6 +542,12 @@ namespace Tensorflow | |||||
| public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null) | public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = pywrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Mul", name, x, y); | |||||
| return _result; | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| @@ -549,6 +562,12 @@ namespace Tensorflow | |||||
| public static Tensor real_div(Tensor x, Tensor y, string name = null) | public static Tensor real_div(Tensor x, Tensor y, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = pywrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "RealDiv", name, x, y); | |||||
| return _result; | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("RealDiv", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("RealDiv", name, args: new { x, y }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| @@ -570,6 +589,12 @@ namespace Tensorflow | |||||
| public static Tensor floor_div(Tensor x, Tensor y, string name = null) | public static Tensor floor_div(Tensor x, Tensor y, string name = null) | ||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | |||||
| { | |||||
| var _result = pywrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "FloorDiv", name, x, y); | |||||
| return _result; | |||||
| } | |||||
| var _op = _op_def_lib._apply_op_helper("FloorDiv", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("FloorDiv", name, args: new { x, y }); | ||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| @@ -712,6 +712,22 @@ namespace Tensorflow | |||||
| var x_dtype = x.dtype.as_base_dtype(); | var x_dtype = x.dtype.as_base_dtype(); | ||||
| var y_dtype = y.dtype.as_base_dtype(); | var y_dtype = y.dtype.as_base_dtype(); | ||||
| if (x_dtype != y_dtype) | |||||
| throw new TypeError($"x and y must have the same dtype, got {x_dtype} != {y_dtype}"); | |||||
| var dtype = x_dtype switch | |||||
| { | |||||
| TF_DataType.TF_UINT8 => TF_DataType.TF_FLOAT, | |||||
| TF_DataType.TF_INT8 => TF_DataType.TF_FLOAT, | |||||
| TF_DataType.TF_INT16 => TF_DataType.TF_FLOAT, | |||||
| TF_DataType.TF_UINT16 => TF_DataType.TF_FLOAT, | |||||
| TF_DataType.TF_INT32 => TF_DataType.TF_DOUBLE, | |||||
| TF_DataType.TF_INT64 => TF_DataType.TF_DOUBLE, | |||||
| _ => x_dtype | |||||
| }; | |||||
| x = cast(x, dtype); | |||||
| y = cast(y, dtype); | |||||
| return gen_math_ops.real_div(x, y, name: name); | return gen_math_ops.real_div(x, y, name: name); | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -130,7 +130,7 @@ namespace Tensorflow | |||||
| public static Tensor operator *(double lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); | public static Tensor operator *(double lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); | ||||
| public static Tensor operator *(Tensor lhs, Complex rhs) => BinaryOpWrapper("mul", lhs, rhs); | public static Tensor operator *(Tensor lhs, Complex rhs) => BinaryOpWrapper("mul", lhs, rhs); | ||||
| public static Tensor operator *(Complex lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); | public static Tensor operator *(Complex lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs); | ||||
| public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); | |||||
| public static Tensor operator /(Tensor lhs, Tensor rhs) => BinaryOpWrapper("truediv", lhs, rhs); | |||||
| public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("div", lhs, rhs); | public static Tensor operator /(Tensor lhs, NDArray rhs) => BinaryOpWrapper("div", lhs, rhs); | ||||
| public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); | public static Tensor operator /(NDArray lhs, Tensor rhs) => BinaryOpWrapper("div", lhs, rhs); | ||||
| public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("div", lhs, rhs); | public static Tensor operator /(Tensor lhs, sbyte rhs) => BinaryOpWrapper("div", lhs, rhs); | ||||
| @@ -97,6 +97,9 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid) | public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid) | ||||
| { | { | ||||
| if (value is Tensor tensor) | |||||
| return tensor; | |||||
| return convert_to_tensor_v2(value, dtype, preferred_dtype, name); | return convert_to_tensor_v2(value, dtype, preferred_dtype, name); | ||||
| } | } | ||||
| @@ -103,6 +103,30 @@ namespace TensorFlowNET.UnitTest | |||||
| protected IntPtr TFE_TensorHandleResolve(IntPtr h, IntPtr status) | protected IntPtr TFE_TensorHandleResolve(IntPtr h, IntPtr status) | ||||
| => c_api.TFE_TensorHandleResolve(h, status); | => c_api.TFE_TensorHandleResolve(h, status); | ||||
| protected string TFE_TensorHandleDeviceName(IntPtr h, IntPtr status) | |||||
| => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status)); | |||||
| protected string TFE_TensorHandleBackingDeviceName(IntPtr h, IntPtr status) | |||||
| => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | |||||
| protected IntPtr TFE_ContextListDevices(IntPtr ctx, IntPtr status) | |||||
| => c_api.TFE_ContextListDevices(ctx, status); | |||||
| protected int TF_DeviceListCount(IntPtr list) | |||||
| => c_api.TF_DeviceListCount(list); | |||||
| protected string TF_DeviceListType(IntPtr list, int index, IntPtr status) | |||||
| => c_api.StringPiece(c_api.TF_DeviceListType(list, index, status)); | |||||
| protected string TF_DeviceListName(IntPtr list, int index, IntPtr status) | |||||
| => c_api.StringPiece(c_api.TF_DeviceListName(list, index, status)); | |||||
| protected void TF_DeleteDeviceList(IntPtr list) | |||||
| => c_api.TF_DeleteDeviceList(list); | |||||
| protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, IntPtr status) | |||||
| => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | |||||
| protected unsafe void memcpy(void * src, IntPtr dst, ulong size) | protected unsafe void memcpy(void * src, IntPtr dst, ulong size) | ||||
| { | { | ||||
| Buffer.MemoryCopy(src, dst.ToPointer(), size, size); | Buffer.MemoryCopy(src, dst.ToPointer(), size, size); | ||||
| @@ -33,7 +33,7 @@ namespace TensorFlowNET.UnitTest.Eager | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| } | } | ||||
| // c_api.TF_DeleteDeviceList(devices); | |||||
| c_api.TF_DeleteDeviceList(devices); | |||||
| c_api.TF_DeleteStatus(status); | c_api.TF_DeleteStatus(status); | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,43 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Eager; | |||||
| using Buffer = System.Buffer; | |||||
| namespace TensorFlowNET.UnitTest.Eager | |||||
| { | |||||
| public partial class CApiEagerTest | |||||
| { | |||||
| /// <summary> | |||||
| /// TEST(CAPI, TensorHandleDevices) | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public unsafe void TensorHandleDevices() | |||||
| { | |||||
| var status = c_api.TF_NewStatus(); | |||||
| var opts = TFE_NewContextOptions(); | |||||
| var ctx = TFE_NewContext(opts, status); | |||||
| TFE_DeleteContextOptions(opts); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| var hcpu = TestMatrixTensorHandle(); | |||||
| var device_name = TFE_TensorHandleDeviceName(hcpu, status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_TRUE(device_name.Contains("CPU:0")); | |||||
| var backing_device_name = TFE_TensorHandleBackingDeviceName(hcpu, status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_TRUE(backing_device_name.Contains("CPU:0")); | |||||
| // Disable the test if no GPU is present. | |||||
| string gpu_device_name = ""; | |||||
| if(GetDeviceName(ctx, ref gpu_device_name, "GPU")) | |||||
| { | |||||
| var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); | |||||
| ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||||
| // shape_op = ShapeOp(ctx, hgpu); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -42,5 +42,30 @@ namespace TensorFlowNET.UnitTest.Eager | |||||
| return op; | return op; | ||||
| } | } | ||||
| bool GetDeviceName(IntPtr ctx, ref string device_name, string device_type) | |||||
| { | |||||
| var status = TF_NewStatus(); | |||||
| var devices = TFE_ContextListDevices(ctx, status); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| int num_devices = TF_DeviceListCount(devices); | |||||
| for (int i = 0; i < num_devices; ++i) | |||||
| { | |||||
| var dev_type = TF_DeviceListType(devices, i, status); | |||||
| CHECK_EQ(TF_GetCode(status), TF_OK, TF_Message(status)); | |||||
| var dev_name = TF_DeviceListName(devices, i, status); | |||||
| CHECK_EQ(TF_GetCode(status), TF_OK, TF_Message(status)); | |||||
| if (dev_type == device_type) | |||||
| { | |||||
| device_name = dev_name; | |||||
| TF_DeleteDeviceList(devices); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| TF_DeleteDeviceList(devices); | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -33,12 +33,11 @@ | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" /> | <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" /> | ||||
| <PackageReference Include="MSTest.TestAdapter" Version="2.1.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="2.1.0" /> | ||||
| <PackageReference Include="MSTest.TestFramework" Version="2.1.0" /> | <PackageReference Include="MSTest.TestFramework" Version="2.1.0" /> | ||||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="1.15.1" /> | |||||
| <PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="1.15.1" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | ||||
| <ProjectReference Include="..\..\src\TensorFlowNET.Hub\Tensorflow.Hub.csproj" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||