| @@ -145,7 +145,7 @@ namespace Tensorflow.Eager | |||
| return flat_result; | |||
| } | |||
| TFE_Op GetOp(Context ctx, string op_or_function_name, Status status) | |||
| SafeOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | |||
| { | |||
| if (thread_local_eager_operation_map.find(ctx, out var op)) | |||
| c_api.TFE_OpReset(op, op_or_function_name, ctx.device_name, status.Handle); | |||
| @@ -159,7 +159,7 @@ namespace Tensorflow.Eager | |||
| return op; | |||
| } | |||
| static UnorderedMap<Context, TFE_Op> thread_local_eager_operation_map = new UnorderedMap<Context, TFE_Op>(); | |||
| static UnorderedMap<Context, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<Context, SafeOpHandle>(); | |||
| bool HasAccumulator() | |||
| { | |||
| @@ -192,7 +192,7 @@ namespace Tensorflow.Eager | |||
| ArgDef input_arg, | |||
| List<object> flattened_attrs, | |||
| List<Tensor> flattened_inputs, | |||
| IntPtr op, | |||
| SafeOpHandle op, | |||
| Status status) | |||
| { | |||
| IntPtr input_handle; | |||
| @@ -224,7 +224,7 @@ namespace Tensorflow.Eager | |||
| return true; | |||
| } | |||
| public void SetOpAttrs(TFE_Op op, params object[] attrs) | |||
| public void SetOpAttrs(SafeOpHandle op, params object[] attrs) | |||
| { | |||
| var status = tf.status; | |||
| var len = attrs.Length; | |||
| @@ -257,7 +257,7 @@ namespace Tensorflow.Eager | |||
| /// <param name="attr_value"></param> | |||
| /// <param name="attr_list_sizes"></param> | |||
| /// <param name="status"></param> | |||
| void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr, | |||
| void SetOpAttrWithDefaults(Context ctx, SafeOpHandle op, AttrDef attr, | |||
| string attr_name, object attr_value, | |||
| Dictionary<string, long> attr_list_sizes, | |||
| Status status) | |||
| @@ -290,7 +290,7 @@ namespace Tensorflow.Eager | |||
| } | |||
| } | |||
| bool SetOpAttrList(Context ctx, IntPtr op, | |||
| bool SetOpAttrList(Context ctx, SafeOpHandle op, | |||
| string key, object value, TF_AttrType type, | |||
| Dictionary<string, long> attr_list_sizes, | |||
| Status status) | |||
| @@ -298,7 +298,7 @@ namespace Tensorflow.Eager | |||
| return false; | |||
| } | |||
| bool SetOpAttrScalar(Context ctx, IntPtr op, | |||
| bool SetOpAttrScalar(Context ctx, SafeOpHandle op, | |||
| string key, object value, TF_AttrType type, | |||
| Dictionary<string, long> attr_list_sizes, | |||
| Status status) | |||
| @@ -0,0 +1,40 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| public sealed class SafeOpHandle : SafeTensorflowHandle | |||
| { | |||
| private SafeOpHandle() | |||
| { | |||
| } | |||
| public SafeOpHandle(IntPtr handle) | |||
| : base(handle) | |||
| { | |||
| } | |||
| protected override bool ReleaseHandle() | |||
| { | |||
| c_api.TFE_DeleteOp(handle); | |||
| SetHandle(IntPtr.Zero); | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,23 +0,0 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| public struct TFE_Op | |||
| { | |||
| IntPtr _handle; | |||
| public TFE_Op(IntPtr handle) | |||
| => _handle = handle; | |||
| public static implicit operator TFE_Op(IntPtr handle) | |||
| => new TFE_Op(handle); | |||
| public static implicit operator IntPtr(TFE_Op tensor) | |||
| => tensor._handle; | |||
| public override string ToString() | |||
| => $"TFE_Op 0x{_handle.ToString("x16")}"; | |||
| } | |||
| } | |||
| @@ -30,7 +30,7 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); | |||
| public static extern TF_AttrType TFE_OpGetAttrType(SafeOpHandle op, string attr_name, ref byte is_list, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | |||
| @@ -43,7 +43,7 @@ namespace Tensorflow | |||
| /// <param name="input_name">const char*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status); | |||
| public static extern int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Returns the length (number of tensors) of the output argument `output_name` | |||
| @@ -54,7 +54,7 @@ namespace Tensorflow | |||
| /// <param name="status"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TFE_OpGetOutputLength(IntPtr op, string input_name, SafeStatusHandle status); | |||
| public static extern int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// | |||
| @@ -65,7 +65,7 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status); | |||
| public static extern int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// | |||
| @@ -98,7 +98,7 @@ namespace Tensorflow | |||
| /// <param name="num_retvals">int*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status); | |||
| public static extern void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// | |||
| @@ -108,7 +108,7 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern TFE_Op TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||
| public static extern SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | |||
| @@ -124,7 +124,7 @@ namespace Tensorflow | |||
| /// <param name="raw_device_name">const char*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpReset(IntPtr op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); | |||
| public static extern void TFE_OpReset(SafeOpHandle op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// | |||
| @@ -140,10 +140,10 @@ namespace Tensorflow | |||
| /// <param name="attr_name">const char*</param> | |||
| /// <param name="value">TF_DataType</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value); | |||
| public static extern void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpSetAttrInt(IntPtr op, string attr_name, long value); | |||
| public static extern void TFE_OpSetAttrInt(SafeOpHandle op, string attr_name, long value); | |||
| /// <summary> | |||
| /// | |||
| @@ -154,10 +154,10 @@ namespace Tensorflow | |||
| /// <param name="num_dims">const int</param> | |||
| /// <param name="out_status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); | |||
| public static extern void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpSetAttrBool(IntPtr op, string attr_name, bool value); | |||
| public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value); | |||
| /// <summary> | |||
| /// | |||
| @@ -167,7 +167,7 @@ namespace Tensorflow | |||
| /// <param name="value">const void*</param> | |||
| /// <param name="length">size_t</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length); | |||
| public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length); | |||
| /// <summary> | |||
| /// | |||
| @@ -176,7 +176,7 @@ namespace Tensorflow | |||
| /// <param name="device_name"></param> | |||
| /// <param name="status"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpSetDevice(TFE_Op op, string device_name, SafeStatusHandle status); | |||
| public static extern void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// | |||
| @@ -185,7 +185,7 @@ namespace Tensorflow | |||
| /// <param name="h">TFE_TensorHandle*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpAddInput(IntPtr op, IntPtr h, SafeStatusHandle status); | |||
| public static extern void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// | |||
| @@ -82,25 +82,25 @@ namespace TensorFlowNET.UnitTest | |||
| protected ulong TF_TensorByteSize(IntPtr t) | |||
| => c_api.TF_TensorByteSize(t); | |||
| protected void TFE_OpAddInput(IntPtr op, IntPtr h, SafeStatusHandle status) | |||
| protected void TFE_OpAddInput(SafeOpHandle op, IntPtr h, SafeStatusHandle status) | |||
| => c_api.TFE_OpAddInput(op, h, status); | |||
| protected void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value) | |||
| protected void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value) | |||
| => c_api.TFE_OpSetAttrType(op, attr_name, value); | |||
| protected void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) | |||
| protected void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) | |||
| => c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status); | |||
| protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length) | |||
| protected void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length) | |||
| => c_api.TFE_OpSetAttrString(op, attr_name, value, length); | |||
| protected IntPtr TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||
| protected SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||
| => c_api.TFE_NewOp(ctx, op_or_function_name, status); | |||
| protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) | |||
| => c_api.TFE_NewTensorHandle(t, status); | |||
| protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) | |||
| protected void TFE_Execute(SafeOpHandle op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) | |||
| => c_api.TFE_Execute(op, retvals, ref num_retvals, status); | |||
| protected SafeContextOptionsHandle TFE_NewContextOptions() | |||
| @@ -109,21 +109,18 @@ namespace TensorFlowNET.UnitTest | |||
| protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status) | |||
| => c_api.TFE_NewContext(opts, status); | |||
| protected int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status) | |||
| protected int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | |||
| => c_api.TFE_OpGetInputLength(op, input_name, status); | |||
| protected int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status) | |||
| protected int TFE_OpAddInputList(SafeOpHandle op, IntPtr[] inputs, int num_inputs, SafeStatusHandle status) | |||
| => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | |||
| protected int TFE_OpGetOutputLength(IntPtr op, string input_name, SafeStatusHandle status) | |||
| protected int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | |||
| => c_api.TFE_OpGetOutputLength(op, input_name, status); | |||
| protected void TFE_DeleteTensorHandle(IntPtr h) | |||
| => c_api.TFE_DeleteTensorHandle(h); | |||
| protected void TFE_DeleteOp(IntPtr op) | |||
| => c_api.TFE_DeleteOp(op); | |||
| protected SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx) | |||
| => c_api.TFE_ContextGetExecutorForThread(ctx); | |||
| @@ -154,7 +151,7 @@ namespace TensorFlowNET.UnitTest | |||
| protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | |||
| => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | |||
| protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status) | |||
| protected void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status) | |||
| => c_api.TFE_OpSetDevice(op, device_name, status); | |||
| } | |||
| } | |||
| @@ -34,13 +34,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var m = TestMatrixTensorHandle(); | |||
| var matmul = MatMulOp(ctx, m, m); | |||
| var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; | |||
| int num_retvals = 2; | |||
| c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||
| EXPECT_EQ(1, num_retvals); | |||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TFE_DeleteOp(matmul); | |||
| using (var matmul = MatMulOp(ctx, m, m)) | |||
| { | |||
| int num_retvals = 2; | |||
| c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||
| EXPECT_EQ(1, num_retvals); | |||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| } | |||
| TFE_DeleteTensorHandle(m); | |||
| t = TFE_TensorHandleResolve(retvals[0], status); | |||
| @@ -26,37 +26,38 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| var input1 = TestMatrixTensorHandle(); | |||
| var input2 = TestMatrixTensorHandle(); | |||
| var identityOp = TFE_NewOp(ctx, "IdentityN", status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var retvals = new IntPtr[2]; | |||
| using (var identityOp = TFE_NewOp(ctx, "IdentityN", status)) | |||
| { | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| // Try to retrieve lengths before building the attributes (should fail) | |||
| EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); | |||
| CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
| CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| // Try to retrieve lengths before building the attributes (should fail) | |||
| EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); | |||
| CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
| CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var inputs = new IntPtr[] { input1, input2 }; | |||
| TFE_OpAddInputList(identityOp, inputs, 2, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var inputs = new IntPtr[] { input1, input2 }; | |||
| TFE_OpAddInputList(identityOp, inputs, 2, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| // Try to retrieve lengths before executing the op (should work) | |||
| EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| // Try to retrieve lengths before executing the op (should work) | |||
| EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var retvals = new IntPtr[2]; | |||
| int num_retvals = 2; | |||
| TFE_Execute(identityOp, retvals, ref num_retvals, status); | |||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| int num_retvals = 2; | |||
| TFE_Execute(identityOp, retvals, ref num_retvals, status); | |||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| // Try to retrieve lengths after executing the op (should work) | |||
| EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| // Try to retrieve lengths after executing the op (should work) | |||
| EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| } | |||
| TFE_DeleteOp(identityOp); | |||
| TFE_DeleteTensorHandle(input1); | |||
| TFE_DeleteTensorHandle(input2); | |||
| TFE_DeleteTensorHandle(retvals[0]); | |||
| @@ -27,27 +27,28 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| var condition = TestScalarTensorHandle(true); | |||
| var t1 = TestMatrixTensorHandle(); | |||
| var t2 = TestAxisTensorHandle(); | |||
| var assertOp = TFE_NewOp(ctx, "Assert", status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TFE_OpAddInput(assertOp, condition, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var data = new[] { condition, t1, t2 }; | |||
| TFE_OpAddInputList(assertOp, data, 3, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var retvals = new IntPtr[1]; | |||
| using (var assertOp = TFE_NewOp(ctx, "Assert", status)) | |||
| { | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TFE_OpAddInput(assertOp, condition, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var data = new[] { condition, t1, t2 }; | |||
| TFE_OpAddInputList(assertOp, data, 3, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| /*var attr_values = Graph.TFE_GetOpDef("Assert").Attr; | |||
| var attr_found = attr_values.First(x => x.Name == "T"); | |||
| EXPECT_NE(attr_found, attr_values.Last());*/ | |||
| // EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); | |||
| //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | |||
| //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | |||
| /*var attr_values = Graph.TFE_GetOpDef("Assert").Attr; | |||
| var attr_found = attr_values.First(x => x.Name == "T"); | |||
| EXPECT_NE(attr_found, attr_values.Last());*/ | |||
| // EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); | |||
| //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | |||
| //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | |||
| var retvals = new IntPtr[1]; | |||
| int num_retvals = 1; | |||
| TFE_Execute(assertOp, retvals, ref num_retvals, status); | |||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| int num_retvals = 1; | |||
| TFE_Execute(assertOp, retvals, ref num_retvals, status); | |||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| } | |||
| TFE_DeleteOp(assertOp); | |||
| TFE_DeleteTensorHandle(condition); | |||
| TFE_DeleteTensorHandle(t1); | |||
| TFE_DeleteTensorHandle(t2); | |||
| @@ -40,25 +40,26 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status); | |||
| ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||
| var shape_op = ShapeOp(ctx, hgpu); | |||
| TFE_OpSetDevice(shape_op, gpu_device_name, status); | |||
| ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||
| var retvals = new IntPtr[1]; | |||
| int num_retvals = 1; | |||
| c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status); | |||
| ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||
| using (var shape_op = ShapeOp(ctx, hgpu)) | |||
| { | |||
| TFE_OpSetDevice(shape_op, gpu_device_name, status); | |||
| ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||
| int num_retvals = 1; | |||
| c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status); | |||
| ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | |||
| // .device of shape is GPU since the op is executed on GPU | |||
| device_name = TFE_TensorHandleDeviceName(retvals[0], status); | |||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| ASSERT_TRUE(device_name.Contains("GPU:0")); | |||
| // .device of shape is GPU since the op is executed on GPU | |||
| device_name = TFE_TensorHandleDeviceName(retvals[0], status); | |||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| ASSERT_TRUE(device_name.Contains("GPU:0")); | |||
| // .backing_device of shape is CPU since the tensor is backed by CPU | |||
| backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status); | |||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| ASSERT_TRUE(backing_device_name.Contains("CPU:0")); | |||
| // .backing_device of shape is CPU since the tensor is backed by CPU | |||
| backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status); | |||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| ASSERT_TRUE(backing_device_name.Contains("CPU:0")); | |||
| } | |||
| TFE_DeleteOp(shape_op); | |||
| TFE_DeleteTensorHandle(retvals[0]); | |||
| TFE_DeleteTensorHandle(hgpu); | |||
| } | |||
| @@ -28,15 +28,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| var var_handle = CreateVariable(ctx, 12.0f, status); | |||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var op = TFE_NewOp(ctx, "ReadVariableOp", status); | |||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||
| TFE_OpAddInput(op, var_handle, status); | |||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| int num_retvals = 1; | |||
| var value_handle = new[] { IntPtr.Zero }; | |||
| TFE_Execute(op, value_handle, ref num_retvals, status); | |||
| TFE_DeleteOp(op); | |||
| using (var op = TFE_NewOp(ctx, "ReadVariableOp", status)) | |||
| { | |||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||
| TFE_OpAddInput(op, var_handle, status); | |||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TFE_Execute(op, value_handle, ref num_retvals, status); | |||
| } | |||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| ASSERT_EQ(1, num_retvals); | |||
| @@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| return th; | |||
| } | |||
| IntPtr MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) | |||
| SafeOpHandle MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) | |||
| { | |||
| using var status = TF_NewStatus(); | |||
| @@ -64,7 +64,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| return false; | |||
| } | |||
| IntPtr ShapeOp(SafeContextHandle ctx, IntPtr a) | |||
| SafeOpHandle ShapeOp(SafeContextHandle ctx, IntPtr a) | |||
| { | |||
| using var status = TF_NewStatus(); | |||
| @@ -79,39 +79,43 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) | |||
| { | |||
| var op = TFE_NewOp(ctx, "VarHandleOp", status); | |||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
| TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||
| TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); | |||
| TFE_OpSetAttrString(op, "container", "", 0); | |||
| TFE_OpSetAttrString(op, "shared_name", "", 0); | |||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
| var var_handle = new IntPtr[1]; | |||
| int num_retvals = 1; | |||
| TFE_Execute(op, var_handle, ref num_retvals, status); | |||
| TFE_DeleteOp(op); | |||
| using (var op = TFE_NewOp(ctx, "VarHandleOp", status)) | |||
| { | |||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
| TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||
| TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); | |||
| TFE_OpSetAttrString(op, "container", "", 0); | |||
| TFE_OpSetAttrString(op, "shared_name", "", 0); | |||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
| TFE_Execute(op, var_handle, ref num_retvals, status); | |||
| } | |||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
| CHECK_EQ(1, num_retvals); | |||
| // Assign 'value' to it. | |||
| op = TFE_NewOp(ctx, "AssignVariableOp", status); | |||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
| TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||
| TFE_OpAddInput(op, var_handle[0], status); | |||
| using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) | |||
| { | |||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
| TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||
| TFE_OpAddInput(op, var_handle[0], status); | |||
| // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. | |||
| var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float)); | |||
| tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); | |||
| // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. | |||
| var t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, new long[0], 0, sizeof(float)); | |||
| tf.memcpy(TF_TensorData(t).ToPointer(), &value, TF_TensorByteSize(t)); | |||
| var value_handle = c_api.TFE_NewTensorHandle(t, status); | |||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
| var value_handle = c_api.TFE_NewTensorHandle(t, status); | |||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
| TFE_OpAddInput(op, value_handle, status); | |||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
| TFE_OpAddInput(op, value_handle, status); | |||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
| num_retvals = 0; | |||
| c_api.TFE_Execute(op, null, ref num_retvals, status); | |||
| } | |||
| num_retvals = 0; | |||
| c_api.TFE_Execute(op, null, ref num_retvals, status); | |||
| TFE_DeleteOp(op); | |||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||
| CHECK_EQ(0, num_retvals); | |||