| @@ -91,6 +91,18 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_DeleteContext(IntPtr ctx); | public static extern void TFE_DeleteContext(IntPtr ctx); | ||||
| /// <summary> | |||||
| /// Execute the operation defined by <paramref name="op"/> and return handles to computed | |||||
| /// tensors in <paramref name="retvals"/>. | |||||
| /// </summary> | |||||
| /// <remarks> | |||||
| /// Upon successful return, the first <paramref name="num_retvals"/> slots in <paramref name="retvals"/> will | |||||
| /// contain handle instances which the caller is responsible for disposing once they are no longer in use. | |||||
| /// </remarks> | |||||
| /// <param name="op"></param> | |||||
| /// <param name="retvals"></param> | |||||
| /// <param name="num_retvals"></param> | |||||
| /// <param name="status"></param> | |||||
| public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | ||||
| { | { | ||||
| unsafe | unsafe | ||||
| @@ -100,6 +112,9 @@ namespace Tensorflow | |||||
| TFE_Execute(op, rawReturns, ref num_retvals, status); | TFE_Execute(op, rawReturns, ref num_retvals, status); | ||||
| for (var i = 0; i < num_retvals; i++) | for (var i = 0; i < num_retvals; i++) | ||||
| { | { | ||||
| // A handle is created for every return, even if rawReturns[i] is null. The resulting handle will be | |||||
| // non-null but invalid, which is the same behavior P/Invoke gives for non-array SafeHandle return | |||||
| // values. | |||||
| retvals[i] = new SafeTensorHandleHandle(rawReturns[i]); | retvals[i] = new SafeTensorHandleHandle(rawReturns[i]); | ||||
| } | } | ||||
| } | } | ||||
| @@ -34,23 +34,23 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| var retvals = new SafeTensorHandleHandle[2]; | var retvals = new SafeTensorHandleHandle[2]; | ||||
| try | |||||
| using (var m = TestMatrixTensorHandle()) | |||||
| using (var matmul = MatMulOp(ctx, m, m)) | |||||
| { | { | ||||
| using (var m = TestMatrixTensorHandle()) | |||||
| using (var matmul = MatMulOp(ctx, m, m)) | |||||
| { | |||||
| int num_retvals; | |||||
| c_api.TFE_Execute(matmul, retvals, out num_retvals, status); | |||||
| EXPECT_EQ(1, num_retvals); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| } | |||||
| int num_retvals; | |||||
| c_api.TFE_Execute(matmul, retvals, out num_retvals, status); | |||||
| EXPECT_EQ(1, num_retvals); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| } | |||||
| try | |||||
| { | |||||
| t = TFE_TensorHandleResolve(retvals[0], status); | t = TFE_TensorHandleResolve(retvals[0], status); | ||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| retvals[0]?.Dispose(); | |||||
| retvals[0].Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -51,6 +51,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| int num_retvals; | int num_retvals; | ||||
| TFE_Execute(identityOp, retvals, out num_retvals, status); | TFE_Execute(identityOp, retvals, out num_retvals, status); | ||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| EXPECT_EQ(2, num_retvals); | |||||
| try | try | ||||
| { | { | ||||
| @@ -62,8 +63,8 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| retvals[0]?.Dispose(); | |||||
| retvals[1]?.Dispose(); | |||||
| retvals[0].Dispose(); | |||||
| retvals[1].Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -46,8 +46,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| int num_retvals; | int num_retvals; | ||||
| TFE_Execute(assertOp, retvals, out num_retvals, status); | TFE_Execute(assertOp, retvals, out num_retvals, status); | ||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| EXPECT_EQ(1, num_retvals); | |||||
| retvals[0]?.Dispose(); | |||||
| retvals[0].Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -49,6 +49,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| int num_retvals; | int num_retvals; | ||||
| c_api.TFE_Execute(shape_op, retvals, out num_retvals, status); | c_api.TFE_Execute(shape_op, retvals, out num_retvals, status); | ||||
| ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status)); | ||||
| ASSERT_EQ(1, num_retvals); | |||||
| try | try | ||||
| { | { | ||||
| @@ -64,7 +65,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| retvals[0]?.Dispose(); | |||||
| retvals[0].Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -38,6 +38,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| TFE_OpAddInput(op, var_handle, status); | TFE_OpAddInput(op, var_handle, status); | ||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| TFE_Execute(op, value_handle, out num_retvals, status); | TFE_Execute(op, value_handle, out num_retvals, status); | ||||
| ASSERT_EQ(1, num_retvals); | |||||
| } | } | ||||
| try | try | ||||
| @@ -57,7 +58,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| value_handle[0]?.Dispose(); | |||||
| value_handle[0].Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -90,11 +90,10 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| TFE_OpSetAttrString(op, "shared_name", "", 0); | TFE_OpSetAttrString(op, "shared_name", "", 0); | ||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | ||||
| TFE_Execute(op, var_handle, out num_retvals, status); | TFE_Execute(op, var_handle, out num_retvals, status); | ||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
| CHECK_EQ(1, num_retvals); | |||||
| } | } | ||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
| CHECK_EQ(1, num_retvals); | |||||
| // Assign 'value' to it. | // Assign 'value' to it. | ||||
| using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) | using (var op = TFE_NewOp(ctx, "AssignVariableOp", status)) | ||||
| { | { | ||||
| @@ -112,13 +111,11 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| TFE_OpAddInput(op, value_handle, status); | TFE_OpAddInput(op, value_handle, status); | ||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | ||||
| num_retvals = 0; | |||||
| c_api.TFE_Execute(op, null, out num_retvals, status); | c_api.TFE_Execute(op, null, out num_retvals, status); | ||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
| CHECK_EQ(0, num_retvals); | |||||
| } | } | ||||
| if (TF_GetCode(status) != TF_OK) return new SafeTensorHandleHandle(IntPtr.Zero); | |||||
| CHECK_EQ(0, num_retvals); | |||||
| return var_handle[0]; | return var_handle[0]; | ||||
| } | } | ||||