| @@ -67,7 +67,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns>TFE_TensorHandle*</returns> | /// <returns>TFE_TensorHandle*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status); | |||||
| public static extern SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) | /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) | ||||
| @@ -33,7 +33,7 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| for (int i = 0; i < inputs.Length; ++i) | for (int i = 0; i < inputs.Length; ++i) | ||||
| { | { | ||||
| IntPtr tensor_handle; | |||||
| SafeTensorHandleHandle tensor_handle; | |||||
| switch (inputs[i]) | switch (inputs[i]) | ||||
| { | { | ||||
| case EagerTensor et: | case EagerTensor et: | ||||
| @@ -50,10 +50,10 @@ namespace Tensorflow.Eager | |||||
| if (status.ok() && attrs != null) | if (status.ok() && attrs != null) | ||||
| SetOpAttrs(op, attrs); | SetOpAttrs(op, attrs); | ||||
| var outputs = new IntPtr[num_outputs]; | |||||
| var outputs = new SafeTensorHandleHandle[num_outputs]; | |||||
| if (status.ok()) | if (status.ok()) | ||||
| { | { | ||||
| c_api.TFE_Execute(op, outputs, ref num_outputs, status.Handle); | |||||
| c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle); | |||||
| status.Check(true); | status.Check(true); | ||||
| } | } | ||||
| return outputs.Select(x => new EagerTensor(x)).ToArray(); | return outputs.Select(x => new EagerTensor(x)).ToArray(); | ||||
| @@ -154,8 +154,8 @@ namespace Tensorflow.Eager | |||||
| num_retvals += (int)delta; | num_retvals += (int)delta; | ||||
| } | } | ||||
| var retVals = new IntPtr[num_retvals]; | |||||
| c_api.TFE_Execute(op, retVals, ref num_retvals, status.Handle); | |||||
| var retVals = new SafeTensorHandleHandle[num_retvals]; | |||||
| c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle); | |||||
| status.Check(true); | status.Check(true); | ||||
| var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray(); | var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray(); | ||||
| @@ -220,7 +220,7 @@ namespace Tensorflow.Eager | |||||
| SafeOpHandle op, | SafeOpHandle op, | ||||
| Status status) | Status status) | ||||
| { | { | ||||
| IntPtr input_handle; | |||||
| SafeTensorHandleHandle input_handle; | |||||
| // ConvertToTensor(); | // ConvertToTensor(); | ||||
| switch (inputs) | switch (inputs) | ||||
| @@ -14,7 +14,7 @@ namespace Tensorflow.Eager | |||||
| } | } | ||||
| public EagerTensor(IntPtr handle) : base(IntPtr.Zero) | |||||
| public EagerTensor(SafeTensorHandleHandle handle) : base(IntPtr.Zero) | |||||
| { | { | ||||
| EagerTensorHandle = handle; | EagerTensorHandle = handle; | ||||
| Resolve(); | Resolve(); | ||||
| @@ -58,14 +58,20 @@ namespace Tensorflow.Eager | |||||
| } | } | ||||
| public override IntPtr ToPointer() | public override IntPtr ToPointer() | ||||
| => EagerTensorHandle; | |||||
| => EagerTensorHandle?.DangerousGetHandle() ?? IntPtr.Zero; | |||||
| protected override void DisposeManagedResources() | |||||
| { | |||||
| base.DisposeManagedResources(); | |||||
| //print($"deleting DeleteTensorHandle {Id} {EagerTensorHandle.ToString("x16")}"); | |||||
| EagerTensorHandle.Dispose(); | |||||
| } | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
| { | { | ||||
| //print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}"); | //print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}"); | ||||
| c_api.TF_DeleteTensor(_handle); | c_api.TF_DeleteTensor(_handle); | ||||
| //print($"deleting DeleteTensorHandle {Id} {EagerTensorHandle.ToString("x16")}"); | |||||
| c_api.TFE_DeleteTensorHandle(EagerTensorHandle); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -8,7 +8,8 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| public partial class EagerTensor | public partial class EagerTensor | ||||
| { | { | ||||
| [Obsolete("Implicit conversion of EagerTensor to IntPtr is not supported.", error: true)] | |||||
| public static implicit operator IntPtr(EagerTensor tensor) | public static implicit operator IntPtr(EagerTensor tensor) | ||||
| => tensor.EagerTensorHandle; | |||||
| => throw new NotSupportedException(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,40 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| public sealed class SafeTensorHandleHandle : SafeTensorflowHandle | |||||
| { | |||||
| private SafeTensorHandleHandle() | |||||
| { | |||||
| } | |||||
| public SafeTensorHandleHandle(IntPtr handle) | |||||
| : base(handle) | |||||
| { | |||||
| } | |||||
| protected override bool ReleaseHandle() | |||||
| { | |||||
| c_api.TFE_DeleteTensorHandle(handle); | |||||
| SetHandle(IntPtr.Zero); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,19 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| [StructLayout(LayoutKind.Sequential)] | |||||
| public struct TFE_TensorHandle | |||||
| { | |||||
| IntPtr _handle; | |||||
| public static implicit operator IntPtr(TFE_TensorHandle tensor) | |||||
| => tensor._handle; | |||||
| public override string ToString() | |||||
| => $"TFE_TensorHandle 0x{_handle.ToString("x16")}"; | |||||
| } | |||||
| } | |||||
| @@ -3,6 +3,7 @@ using System; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Tensorflow.Device; | using Tensorflow.Device; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -66,7 +67,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status); | |||||
| public static extern int TFE_OpAddInputList(SafeOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -90,6 +91,20 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_DeleteContext(IntPtr ctx); | public static extern void TFE_DeleteContext(IntPtr ctx); | ||||
| public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| num_retvals = retvals?.Length ?? 0; | |||||
| var rawReturns = stackalloc IntPtr[num_retvals]; | |||||
| TFE_Execute(op, rawReturns, ref num_retvals, status); | |||||
| for (var i = 0; i < num_retvals; i++) | |||||
| { | |||||
| retvals[i] = new SafeTensorHandleHandle(rawReturns[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Execute the operation defined by 'op' and return handles to computed | /// Execute the operation defined by 'op' and return handles to computed | ||||
| /// tensors in `retvals`. | /// tensors in `retvals`. | ||||
| @@ -99,7 +114,7 @@ namespace Tensorflow | |||||
| /// <param name="num_retvals">int*</param> | /// <param name="num_retvals">int*</param> | ||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status); | |||||
| private static unsafe extern void TFE_Execute(SafeOpHandle op, IntPtr* retvals, ref int num_retvals, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -198,7 +213,7 @@ namespace Tensorflow | |||||
| /// <param name="h">TFE_TensorHandle*</param> | /// <param name="h">TFE_TensorHandle*</param> | ||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status); | |||||
| public static extern void TFE_OpAddInput(SafeOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -206,10 +221,10 @@ namespace Tensorflow | |||||
| /// <param name="t">const tensorflow::Tensor&</param> | /// <param name="t">const tensorflow::Tensor&</param> | ||||
| /// <returns>TFE_TensorHandle*</returns> | /// <returns>TFE_TensorHandle*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_TensorHandle TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status); | |||||
| public static extern SafeTensorHandleHandle TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_EagerTensorHandle(IntPtr t); | |||||
| public static extern SafeTensorHandleHandle TFE_EagerTensorHandle(IntPtr t); | |||||
| /// <summary> | /// <summary> | ||||
| /// Sets the default execution mode (sync/async). Note that this can be | /// Sets the default execution mode (sync/async). Note that this can be | ||||
| @@ -226,7 +241,7 @@ namespace Tensorflow | |||||
| /// <param name="h">TFE_TensorHandle*</param> | /// <param name="h">TFE_TensorHandle*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TF_DataType TFE_TensorHandleDataType(IntPtr h); | |||||
| public static extern TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h); | |||||
| /// <summary> | /// <summary> | ||||
| /// This function will block till the operation that produces `h` has | /// This function will block till the operation that produces `h` has | ||||
| @@ -237,7 +252,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_TensorHandleResolve(IntPtr h, SafeStatusHandle status); | |||||
| public static extern IntPtr TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| @@ -247,10 +262,10 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TFE_TensorHandleNumDims(IntPtr h, SafeStatusHandle status); | |||||
| public static extern int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TFE_TensorHandleDim(IntPtr h, int dim, SafeStatusHandle status); | |||||
| public static extern int TFE_TensorHandleDim(SafeTensorHandleHandle h, int dim, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the device of the operation that produced `h`. If `h` was produced by | /// Returns the device of the operation that produced `h`. If `h` was produced by | ||||
| @@ -263,7 +278,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_TensorHandleDeviceName(IntPtr h, SafeStatusHandle status); | |||||
| public static extern IntPtr TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the name of the device in whose memory `h` resides. | /// Returns the name of the device in whose memory `h` resides. | ||||
| @@ -272,7 +287,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status); | |||||
| public static extern IntPtr TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -12,7 +12,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class EagerTensorV2 : DisposableObject, ITensor | public class EagerTensorV2 : DisposableObject, ITensor | ||||
| { | { | ||||
| IntPtr EagerTensorHandle; | |||||
| SafeTensorHandleHandle EagerTensorHandle; | |||||
| public string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.status.Handle)); | public string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.status.Handle)); | ||||
| public EagerTensorV2(IntPtr handle) | public EagerTensorV2(IntPtr handle) | ||||
| @@ -64,10 +64,14 @@ namespace Tensorflow | |||||
| } | } | ||||
| }*/ | }*/ | ||||
| protected override void DisposeManagedResources() | |||||
| { | |||||
| EagerTensorHandle.Dispose(); | |||||
| } | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
| { | { | ||||
| c_api.TF_DeleteTensor(_handle); | c_api.TF_DeleteTensor(_handle); | ||||
| c_api.TFE_DeleteTensorHandle(EagerTensorHandle); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -23,6 +23,7 @@ using System.Linq; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Eager; | |||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -94,7 +95,7 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// TFE_TensorHandle | /// TFE_TensorHandle | ||||
| /// </summary> | /// </summary> | ||||
| public IntPtr EagerTensorHandle { get; set; } | |||||
| public SafeTensorHandleHandle EagerTensorHandle { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the shape of a tensor. | /// Returns the shape of a tensor. | ||||
| @@ -0,0 +1,132 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Runtime.ExceptionServices; | |||||
| using System.Runtime.InteropServices; | |||||
| namespace Tensorflow.Util | |||||
| { | |||||
| internal sealed class SafeHandleArrayMarshaler : ICustomMarshaler | |||||
| { | |||||
| private static readonly SafeHandleArrayMarshaler Instance = new SafeHandleArrayMarshaler(); | |||||
| private SafeHandleArrayMarshaler() | |||||
| { | |||||
| } | |||||
| #pragma warning disable IDE0060 // Remove unused parameter (method is used implicitly) | |||||
| public static ICustomMarshaler GetInstance(string cookie) | |||||
| #pragma warning restore IDE0060 // Remove unused parameter | |||||
| { | |||||
| return Instance; | |||||
| } | |||||
| public int GetNativeDataSize() | |||||
| { | |||||
| return IntPtr.Size; | |||||
| } | |||||
| [HandleProcessCorruptedStateExceptions] | |||||
| public IntPtr MarshalManagedToNative(object ManagedObj) | |||||
| { | |||||
| if (ManagedObj is null) | |||||
| return IntPtr.Zero; | |||||
| var array = (SafeHandle[])ManagedObj; | |||||
| var native = IntPtr.Zero; | |||||
| var marshaledArrayHandle = false; | |||||
| try | |||||
| { | |||||
| native = Marshal.AllocHGlobal((array.Length + 1) * IntPtr.Size); | |||||
| Marshal.WriteIntPtr(native, GCHandle.ToIntPtr(GCHandle.Alloc(array))); | |||||
| marshaledArrayHandle = true; | |||||
| var i = 0; | |||||
| var success = false; | |||||
| try | |||||
| { | |||||
| for (i = 0; i < array.Length; i++) | |||||
| { | |||||
| success = false; | |||||
| var current = array[i]; | |||||
| var currentHandle = IntPtr.Zero; | |||||
| if (current is object) | |||||
| { | |||||
| current.DangerousAddRef(ref success); | |||||
| currentHandle = current.DangerousGetHandle(); | |||||
| } | |||||
| Marshal.WriteIntPtr(native, ofs: (i + 1) * IntPtr.Size, currentHandle); | |||||
| } | |||||
| return IntPtr.Add(native, IntPtr.Size); | |||||
| } | |||||
| catch | |||||
| { | |||||
| // Clean up any handles which were leased prior to the exception | |||||
| var total = success ? i + 1 : i; | |||||
| for (var j = 0; j < total; j++) | |||||
| { | |||||
| var current = array[i]; | |||||
| if (current is object) | |||||
| current.DangerousRelease(); | |||||
| } | |||||
| throw; | |||||
| } | |||||
| } | |||||
| catch | |||||
| { | |||||
| if (native != IntPtr.Zero) | |||||
| { | |||||
| if (marshaledArrayHandle) | |||||
| GCHandle.FromIntPtr(Marshal.ReadIntPtr(native)).Free(); | |||||
| Marshal.FreeHGlobal(native); | |||||
| } | |||||
| throw; | |||||
| } | |||||
| } | |||||
| public void CleanUpNativeData(IntPtr pNativeData) | |||||
| { | |||||
| if (pNativeData == IntPtr.Zero) | |||||
| return; | |||||
| var managedHandle = GCHandle.FromIntPtr(Marshal.ReadIntPtr(pNativeData, -IntPtr.Size)); | |||||
| var array = (SafeHandle[])managedHandle.Target; | |||||
| managedHandle.Free(); | |||||
| for (var i = 0; i < array.Length; i++) | |||||
| { | |||||
| if (array[i] is object && !array[i].IsClosed) | |||||
| array[i].DangerousRelease(); | |||||
| } | |||||
| } | |||||
| public object MarshalNativeToManaged(IntPtr pNativeData) | |||||
| { | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| public void CleanUpManagedData(object ManagedObj) | |||||
| { | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -56,10 +56,10 @@ namespace TensorFlowNET.UnitTest | |||||
| protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value) | protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value) | ||||
| => c_api.TF_SetAttrBool(desc, attrName, value); | => c_api.TF_SetAttrBool(desc, attrName, value); | ||||
| protected TF_DataType TFE_TensorHandleDataType(IntPtr h) | |||||
| protected TF_DataType TFE_TensorHandleDataType(SafeTensorHandleHandle h) | |||||
| => c_api.TFE_TensorHandleDataType(h); | => c_api.TFE_TensorHandleDataType(h); | ||||
| protected int TFE_TensorHandleNumDims(IntPtr h, SafeStatusHandle status) | |||||
| protected int TFE_TensorHandleNumDims(SafeTensorHandleHandle h, SafeStatusHandle status) | |||||
| => c_api.TFE_TensorHandleNumDims(h, status); | => c_api.TFE_TensorHandleNumDims(h, status); | ||||
| protected TF_Code TF_GetCode(Status s) | protected TF_Code TF_GetCode(Status s) | ||||
| @@ -83,7 +83,7 @@ namespace TensorFlowNET.UnitTest | |||||
| protected ulong TF_TensorByteSize(IntPtr t) | protected ulong TF_TensorByteSize(IntPtr t) | ||||
| => c_api.TF_TensorByteSize(t); | => c_api.TF_TensorByteSize(t); | ||||
| protected void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status) | |||||
| protected void TFE_OpAddInput(SafeOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status) | |||||
| => c_api.TFE_OpAddInput(op, h, status); | => c_api.TFE_OpAddInput(op, h, status); | ||||
| protected void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value) | protected void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value) | ||||
| @@ -98,11 +98,11 @@ namespace TensorFlowNET.UnitTest | |||||
| protected SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | protected SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | ||||
| => c_api.TFE_NewOp(ctx, op_or_function_name, status); | => c_api.TFE_NewOp(ctx, op_or_function_name, status); | ||||
| protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) | |||||
| protected SafeTensorHandleHandle TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) | |||||
| => c_api.TFE_NewTensorHandle(t, status); | => c_api.TFE_NewTensorHandle(t, status); | ||||
| protected void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) | |||||
| => c_api.TFE_Execute(op, retvals, ref num_retvals, status); | |||||
| protected void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||||
| => c_api.TFE_Execute(op, retvals, out num_retvals, status); | |||||
| protected SafeContextOptionsHandle TFE_NewContextOptions() | protected SafeContextOptionsHandle TFE_NewContextOptions() | ||||
| => c_api.TFE_NewContextOptions(); | => c_api.TFE_NewContextOptions(); | ||||
| @@ -113,7 +113,7 @@ namespace TensorFlowNET.UnitTest | |||||
| protected int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | protected int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | ||||
| => c_api.TFE_OpGetInputLength(op, input_name, status); | => c_api.TFE_OpGetInputLength(op, input_name, status); | ||||
| protected int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status) | |||||
| protected int TFE_OpAddInputList(SafeOpHandle op, SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status) | |||||
| => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | ||||
| protected int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | protected int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | ||||
| @@ -128,13 +128,13 @@ namespace TensorFlowNET.UnitTest | |||||
| protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status) | protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status) | ||||
| => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); | => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); | ||||
| protected IntPtr TFE_TensorHandleResolve(IntPtr h, SafeStatusHandle status) | |||||
| protected IntPtr TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status) | |||||
| => c_api.TFE_TensorHandleResolve(h, status); | => c_api.TFE_TensorHandleResolve(h, status); | ||||
| protected string TFE_TensorHandleDeviceName(IntPtr h, SafeStatusHandle status) | |||||
| protected string TFE_TensorHandleDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status) | |||||
| => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status)); | => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status)); | ||||
| protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status) | |||||
| protected string TFE_TensorHandleBackingDeviceName(SafeTensorHandleHandle h, SafeStatusHandle status) | |||||
| => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | ||||
| protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | protected SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | ||||
| @@ -149,7 +149,7 @@ namespace TensorFlowNET.UnitTest | |||||
| protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) | protected string TF_DeviceListName(SafeDeviceListHandle list, int index, SafeStatusHandle status) | ||||
| => c_api.TF_DeviceListName(list, index, status); | => c_api.TF_DeviceListName(list, index, status); | ||||
| protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | |||||
| protected SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | |||||
| => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | ||||
| protected void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status) | protected void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status) | ||||
| @@ -33,21 +33,25 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| { | { | ||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| var m = TestMatrixTensorHandle(); | |||||
| var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; | |||||
| using (var matmul = MatMulOp(ctx, m, m)) | |||||
| var retvals = new SafeTensorHandleHandle[2]; | |||||
| try | |||||
| { | { | ||||
| int num_retvals = 2; | |||||
| c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||||
| EXPECT_EQ(1, num_retvals); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| } | |||||
| TFE_DeleteTensorHandle(m); | |||||
| using (var m = TestMatrixTensorHandle()) | |||||
| using (var matmul = MatMulOp(ctx, m, m)) | |||||
| { | |||||
| int num_retvals; | |||||
| c_api.TFE_Execute(matmul, retvals, out num_retvals, status); | |||||
| EXPECT_EQ(1, num_retvals); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| } | |||||
| t = TFE_TensorHandleResolve(retvals[0], status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| TFE_DeleteTensorHandle(retvals[0]); | |||||
| t = TFE_TensorHandleResolve(retvals[0], status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| } | |||||
| finally | |||||
| { | |||||
| retvals[0]?.Dispose(); | |||||
| } | |||||
| } | } | ||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| @@ -24,9 +24,10 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| using var ctx = NewContext(status); | using var ctx = NewContext(status); | ||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| var input1 = TestMatrixTensorHandle(); | |||||
| var input2 = TestMatrixTensorHandle(); | |||||
| var retvals = new IntPtr[2]; | |||||
| using var input1 = TestMatrixTensorHandle(); | |||||
| using var input2 = TestMatrixTensorHandle(); | |||||
| var retvals = new SafeTensorHandleHandle[2]; | |||||
| using (var identityOp = TFE_NewOp(ctx, "IdentityN", status)) | using (var identityOp = TFE_NewOp(ctx, "IdentityN", status)) | ||||
| { | { | ||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| @@ -37,7 +38,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); | EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); | ||||
| CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| var inputs = new IntPtr[] { input1, input2 }; | |||||
| var inputs = new SafeTensorHandleHandle[] { input1, input2 }; | |||||
| TFE_OpAddInputList(identityOp, inputs, 2, status); | TFE_OpAddInputList(identityOp, inputs, 2, status); | ||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| @@ -47,21 +48,24 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | ||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| int num_retvals = 2; | |||||
| TFE_Execute(identityOp, retvals, ref num_retvals, status); | |||||
| int num_retvals; | |||||
| TFE_Execute(identityOp, retvals, out num_retvals, status); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| // Try to retrieve lengths after executing the op (should work) | |||||
| EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| try | |||||
| { | |||||
| // Try to retrieve lengths after executing the op (should work) | |||||
| EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| } | |||||
| finally | |||||
| { | |||||
| retvals[0]?.Dispose(); | |||||
| retvals[1]?.Dispose(); | |||||
| } | |||||
| } | } | ||||
| TFE_DeleteTensorHandle(input1); | |||||
| TFE_DeleteTensorHandle(input2); | |||||
| TFE_DeleteTensorHandle(retvals[0]); | |||||
| TFE_DeleteTensorHandle(retvals[1]); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,5 +1,4 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| @@ -24,10 +23,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| using var ctx = NewContext(status); | using var ctx = NewContext(status); | ||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| var condition = TestScalarTensorHandle(true); | |||||
| var t1 = TestMatrixTensorHandle(); | |||||
| var t2 = TestAxisTensorHandle(); | |||||
| var retvals = new IntPtr[1]; | |||||
| using var condition = TestScalarTensorHandle(true); | |||||
| using var t1 = TestMatrixTensorHandle(); | |||||
| using var t2 = TestAxisTensorHandle(); | |||||
| using (var assertOp = TFE_NewOp(ctx, "Assert", status)) | using (var assertOp = TFE_NewOp(ctx, "Assert", status)) | ||||
| { | { | ||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| @@ -44,15 +42,13 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | ||||
| //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | ||||
| int num_retvals = 1; | |||||
| TFE_Execute(assertOp, retvals, ref num_retvals, status); | |||||
| var retvals = new SafeTensorHandleHandle[1]; | |||||
| int num_retvals; | |||||
| TFE_Execute(assertOp, retvals, out num_retvals, status); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| } | |||||
| TFE_DeleteTensorHandle(condition); | |||||
| TFE_DeleteTensorHandle(t1); | |||||
| TFE_DeleteTensorHandle(t2); | |||||
| TFE_DeleteTensorHandle(retvals[0]); | |||||
| retvals[0]?.Dispose(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| [TestMethod] | [TestMethod] | ||||
| public unsafe void TensorHandle() | public unsafe void TensorHandle() | ||||
| { | { | ||||
| var h = TestMatrixTensorHandle(); | |||||
| using var h = TestMatrixTensorHandle(); | |||||
| EXPECT_EQ(TF_FLOAT, c_api.TFE_TensorHandleDataType(h)); | EXPECT_EQ(TF_FLOAT, c_api.TFE_TensorHandleDataType(h)); | ||||
| var status = c_api.TF_NewStatus(); | var status = c_api.TF_NewStatus(); | ||||
| @@ -28,7 +28,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| EXPECT_EQ(3.0f, data[2]); | EXPECT_EQ(3.0f, data[2]); | ||||
| EXPECT_EQ(4.0f, data[3]); | EXPECT_EQ(4.0f, data[3]); | ||||
| c_api.TF_DeleteTensor(t); | c_api.TF_DeleteTensor(t); | ||||
| c_api.TFE_DeleteTensorHandle(h); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -24,47 +24,52 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| using var ctx = NewContext(status); | using var ctx = NewContext(status); | ||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| var hcpu = TestMatrixTensorHandle(); | |||||
| var device_name = TFE_TensorHandleDeviceName(hcpu, status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_TRUE(device_name.Contains("CPU:0")); | |||||
| var backing_device_name = TFE_TensorHandleBackingDeviceName(hcpu, status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_TRUE(backing_device_name.Contains("CPU:0")); | |||||
| // Disable the test if no GPU is present. | |||||
| string gpu_device_name = ""; | |||||
| if(GetDeviceName(ctx, ref gpu_device_name, "GPU")) | |||||
| using (var hcpu = TestMatrixTensorHandle()) | |||||
| { | { | ||||
| var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); | |||||
| ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||||
| var device_name = TFE_TensorHandleDeviceName(hcpu, status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_TRUE(device_name.Contains("CPU:0")); | |||||
| var retvals = new IntPtr[1]; | |||||
| using (var shape_op = ShapeOp(ctx, hgpu)) | |||||
| var backing_device_name = TFE_TensorHandleBackingDeviceName(hcpu, status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_TRUE(backing_device_name.Contains("CPU:0")); | |||||
| // Disable the test if no GPU is present. | |||||
| string gpu_device_name = ""; | |||||
| if (GetDeviceName(ctx, ref gpu_device_name, "GPU")) | |||||
| { | { | ||||
| TFE_OpSetDevice(shape_op, gpu_device_name, status); | |||||
| ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||||
| int num_retvals = 1; | |||||
| c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status); | |||||
| using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); | |||||
| ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | ||||
| // .device of shape is GPU since the op is executed on GPU | |||||
| device_name = TFE_TensorHandleDeviceName(retvals[0], status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_TRUE(device_name.Contains("GPU:0")); | |||||
| var retvals = new SafeTensorHandleHandle[1]; | |||||
| using (var shape_op = ShapeOp(ctx, hgpu)) | |||||
| { | |||||
| TFE_OpSetDevice(shape_op, gpu_device_name, status); | |||||
| ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||||
| int num_retvals; | |||||
| c_api.TFE_Execute(shape_op, retvals, out num_retvals, status); | |||||
| ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||||
| // .backing_device of shape is CPU since the tensor is backed by CPU | |||||
| backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_TRUE(backing_device_name.Contains("CPU:0")); | |||||
| } | |||||
| try | |||||
| { | |||||
| // .device of shape is GPU since the op is executed on GPU | |||||
| device_name = TFE_TensorHandleDeviceName(retvals[0], status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_TRUE(device_name.Contains("GPU:0")); | |||||
| TFE_DeleteTensorHandle(retvals[0]); | |||||
| TFE_DeleteTensorHandle(hgpu); | |||||
| // .backing_device of shape is CPU since the tensor is backed by CPU | |||||
| backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_TRUE(backing_device_name.Contains("CPU:0")); | |||||
| } | |||||
| finally | |||||
| { | |||||
| retvals[0]?.Dispose(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| TFE_DeleteTensorHandle(hcpu); | |||||
| // not export api | // not export api | ||||
| using var executor = TFE_ContextGetExecutorForThread(ctx); | using var executor = TFE_ContextGetExecutorForThread(ctx); | ||||
| TFE_ExecutorWaitForAllPendingNodes(executor, status); | TFE_ExecutorWaitForAllPendingNodes(executor, status); | ||||
| @@ -25,35 +25,42 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| using var ctx = NewContext(status); | using var ctx = NewContext(status); | ||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| var var_handle = CreateVariable(ctx, 12.0f, status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| int num_retvals = 1; | |||||
| var value_handle = new[] { IntPtr.Zero }; | |||||
| using (var op = TFE_NewOp(ctx, "ReadVariableOp", status)) | |||||
| using (var var_handle = CreateVariable(ctx, 12.0f, status)) | |||||
| { | { | ||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||||
| TFE_OpAddInput(op, var_handle, status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| TFE_Execute(op, value_handle, ref num_retvals, status); | |||||
| } | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_EQ(1, num_retvals); | |||||
| EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle[0])); | |||||
| EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle[0], status)); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| var value = 0f; // new float[1]; | |||||
| var t = TFE_TensorHandleResolve(value_handle[0], status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_EQ(sizeof(float), (int)TF_TensorByteSize(t)); | |||||
| tf.memcpy(&value, TF_TensorData(t).ToPointer(), sizeof(float)); | |||||
| c_api.TF_DeleteTensor(t); | |||||
| EXPECT_EQ(12.0f, value); | |||||
| int num_retvals = 1; | |||||
| var value_handle = new SafeTensorHandleHandle[1]; | |||||
| using (var op = TFE_NewOp(ctx, "ReadVariableOp", status)) | |||||
| { | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||||
| TFE_OpAddInput(op, var_handle, status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| TFE_Execute(op, value_handle, out num_retvals, status); | |||||
| } | |||||
| try | |||||
| { | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_EQ(1, num_retvals); | |||||
| EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle[0])); | |||||
| EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle[0], status)); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| var value = 0f; // new float[1]; | |||||
| var t = TFE_TensorHandleResolve(value_handle[0], status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| ASSERT_EQ(sizeof(float), (int)TF_TensorByteSize(t)); | |||||
| tf.memcpy(&value, TF_TensorData(t).ToPointer(), sizeof(float)); | |||||
| c_api.TF_DeleteTensor(t); | |||||
| EXPECT_EQ(12.0f, value); | |||||
| } | |||||
| finally | |||||
| { | |||||
| value_handle[0]?.Dispose(); | |||||
| } | |||||
| } | |||||
| TFE_DeleteTensorHandle(var_handle); | |||||
| TFE_DeleteTensorHandle(value_handle[0]); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| } | } | ||||
| } | } | ||||
| @@ -12,7 +12,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| [TestClass] | [TestClass] | ||||
| public partial class CApiEagerTest : CApiTest | public partial class CApiEagerTest : CApiTest | ||||
| { | { | ||||
| IntPtr TestMatrixTensorHandle() | |||||
| SafeTensorHandleHandle TestMatrixTensorHandle() | |||||
| { | { | ||||
| var dims = new long[] { 2, 2 }; | var dims = new long[] { 2, 2 }; | ||||
| var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; | var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; | ||||
| @@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| return th; | return th; | ||||
| } | } | ||||
| SafeOpHandle MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) | |||||
| SafeOpHandle MatMulOp(SafeContextHandle ctx, SafeTensorHandleHandle a, SafeTensorHandleHandle b) | |||||
| { | { | ||||
| using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
| @@ -64,7 +64,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| return false; | return false; | ||||
| } | } | ||||
| SafeOpHandle ShapeOp(SafeContextHandle ctx, IntPtr a) | |||||
| SafeOpHandle ShapeOp(SafeContextHandle ctx, SafeTensorHandleHandle a) | |||||
| { | { | ||||
| using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
| @@ -77,28 +77,28 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| return op; | return op; | ||||
| } | } | ||||
| unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) | |||||
| unsafe SafeTensorHandleHandle CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) | |||||
| { | { | ||||
| var var_handle = new IntPtr[1]; | |||||
| int num_retvals = 1; | |||||
| var var_handle = new SafeTensorHandleHandle[1]; | |||||
| int num_retvals; | |||||
| using (var op = TFE_NewOp(ctx, "VarHandleOp", status)) | using (var op = TFE_NewOp(ctx, "VarHandleOp", status)) | ||||
| { | { | ||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
| TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | ||||
| TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); | TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); | ||||
| TFE_OpSetAttrString(op, "container", "", 0); | TFE_OpSetAttrString(op, "container", "", 0); | ||||
| TFE_OpSetAttrString(op, "shared_name", "", 0); | TFE_OpSetAttrString(op, "shared_name", "", 0); | ||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| TFE_Execute(op, var_handle, ref num_retvals, status); | |||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
| TFE_Execute(op, var_handle, out num_retvals, status); | |||||
| } | } | ||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
| CHECK_EQ(1, num_retvals); | CHECK_EQ(1, num_retvals); | ||||
| // Assign 'value' to it. | // Assign 'value' to it. | ||||
| using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) | using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) | ||||
| { | { | ||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
| TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | ||||
| TFE_OpAddInput(op, var_handle[0], status); | TFE_OpAddInput(op, var_handle[0], status); | ||||
| @@ -107,22 +107,22 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); | tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); | ||||
| var value_handle = c_api.TFE_NewTensorHandle(t, status); | var value_handle = c_api.TFE_NewTensorHandle(t, status); | ||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
| TFE_OpAddInput(op, value_handle, status); | TFE_OpAddInput(op, value_handle, status); | ||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
| num_retvals = 0; | num_retvals = 0; | ||||
| c_api.TFE_Execute(op, null, ref num_retvals, status); | |||||
| c_api.TFE_Execute(op, null, out num_retvals, status); | |||||
| } | } | ||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
| CHECK_EQ(0, num_retvals); | CHECK_EQ(0, num_retvals); | ||||
| return var_handle[0]; | return var_handle[0]; | ||||
| } | } | ||||
| IntPtr TestAxisTensorHandle() | |||||
| SafeTensorHandleHandle TestAxisTensorHandle() | |||||
| { | { | ||||
| var dims = new long[] { 1 }; | var dims = new long[] { 1 }; | ||||
| var data = new int[] { 1 }; | var data = new int[] { 1 }; | ||||
| @@ -135,7 +135,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| return th; | return th; | ||||
| } | } | ||||
| IntPtr TestScalarTensorHandle(bool value) | |||||
| SafeTensorHandleHandle TestScalarTensorHandle(bool value) | |||||
| { | { | ||||
| var data = new[] { value }; | var data = new[] { value }; | ||||
| var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool)); | var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool)); | ||||
| @@ -147,7 +147,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| return th; | return th; | ||||
| } | } | ||||
| IntPtr TestScalarTensorHandle(float value) | |||||
| SafeTensorHandleHandle TestScalarTensorHandle(float value) | |||||
| { | { | ||||
| var data = new [] { value }; | var data = new [] { value }; | ||||
| var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float)); | var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float)); | ||||