diff --git a/src/TensorFlowNET.Core/Eager/Context.cs b/src/TensorFlowNET.Core/Eager/Context.cs
index 3d9c875d..f32790d8 100644
--- a/src/TensorFlowNET.Core/Eager/Context.cs
+++ b/src/TensorFlowNET.Core/Eager/Context.cs
@@ -4,13 +4,29 @@ using System.Text;
namespace Tensorflow.Eager
{
- public class Context
+ public class Context : IDisposable
{
+ private IntPtr _handle;
+
public static int GRAPH_MODE = 0;
public static int EAGER_MODE = 1;
public int default_execution_mode;
+ public Context(ContextOptions opts, Status status)
+ {
+ _handle = c_api.TFE_NewContext(opts, status);
+ status.Check(true);
+ }
+
+ public void Dispose()
+ {
+ c_api.TFE_DeleteContext(_handle);
+ }
+ public static implicit operator IntPtr(Context ctx)
+ {
+ return ctx._handle;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Eager/ContextOptions.cs b/src/TensorFlowNET.Core/Eager/ContextOptions.cs
index 8bc49c8d..46d40bf6 100644
--- a/src/TensorFlowNET.Core/Eager/ContextOptions.cs
+++ b/src/TensorFlowNET.Core/Eager/ContextOptions.cs
@@ -18,9 +18,9 @@ namespace Tensorflow.Eager
c_api.TFE_DeleteContextOptions(_handle);
}
- public static implicit operator IntPtr(ContextOptions ctx)
+ public static implicit operator IntPtr(ContextOptions opts)
{
- return ctx._handle;
+ return opts._handle;
}
}
}
diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
index b8a88cf5..a6d9c8cd 100644
--- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs
+++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
@@ -10,24 +10,106 @@ namespace Tensorflow
///
/// Return a new options object.
///
- ///
+ /// TFE_ContextOptions*
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_NewContextOptions();
///
/// Destroy an options object.
///
- ///
+ /// TFE_ContextOptions*
[DllImport(TensorFlowLibName)]
public static extern void TFE_DeleteContextOptions(IntPtr options);
+ ///
+ ///
+ ///
+ /// const TFE_ContextOptions*
+ /// TF_Status*
+ /// TFE_Context*
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_NewContext(IntPtr opts, IntPtr status);
+ ///
+ ///
+ ///
+ /// TFE_Context*
[DllImport(TensorFlowLibName)]
public static extern void TFE_DeleteContext(IntPtr ctx);
+ ///
+ /// Execute the operation defined by 'op' and return handles to computed
+ /// tensors in `retvals`.
+ ///
+ /// TFE_Op*
+ /// TFE_TensorHandle**
+ /// int*
+ /// TF_Status*
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFE_Execute(IntPtr op, IntPtr retvals, int[] num_retvals, IntPtr status);
+
+ ///
+ ///
+ ///
+ /// TFE_Context*
+ /// const char*
+ /// TF_Status*
+ ///
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status);
+
+ ///
+ ///
+ ///
+ /// TFE_Op*
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFE_DeleteOp(IntPtr op);
+
+ ///
+ ///
+ ///
+ /// TFE_Op*
+ /// const char*
+ /// TF_DataType
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value);
+
+ ///
+ ///
+ ///
+ /// TFE_Op*
+ /// const char*
+ /// const int64_t*
+ /// const int
+ /// TF_Status*
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, Status out_status);
+
+ ///
+ ///
+ ///
+ /// TFE_Op*
+ /// const char*
+ /// const void*
+ /// size_t
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length);
+
+ ///
+ ///
+ ///
+ /// TFE_Op*
+ /// TFE_TensorHandle*
+ /// TF_Status*
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFE_OpAddInput(IntPtr op, IntPtr h, IntPtr status);
+
+ ///
+ ///
+ ///
+ /// const tensorflow::Tensor&
+ /// TFE_TensorHandle*
+ [DllImport(TensorFlowLibName)]
+ public static extern IntPtr TFE_NewTensorHandle(IntPtr t);
}
}
diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs
index 26b5e114..403d8897 100644
--- a/src/TensorFlowNET.Core/ops.name_scope.cs
+++ b/src/TensorFlowNET.Core/ops.name_scope.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Eager;
namespace Tensorflow
{
@@ -20,7 +21,7 @@ namespace Tensorflow
_name = name;
_default_name = default_name;
_values = values;
- _ctx = new Context();
+ // _ctx = new Context();
}
public string __enter__()
diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs
index 3e21d929..5d63a411 100644
--- a/src/TensorFlowNET.Core/tf.cs
+++ b/src/TensorFlowNET.Core/tf.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Eager;
namespace Tensorflow
{
@@ -12,7 +13,7 @@ namespace Tensorflow
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
public static TF_DataType chars = TF_DataType.TF_STRING;
- public static Context context = new Context();
+ public static Context context;
public static Graph g = new Graph(c_api.TF_NewGraph());
@@ -28,6 +29,7 @@ namespace Tensorflow
public static void enable_eager_execution()
{
+ // contex = new Context();
context.default_execution_mode = Context.EAGER_MODE;
}
diff --git a/test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs b/test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs
index 12830db6..1a5cb1a5 100644
--- a/test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs
+++ b/test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs
@@ -3,6 +3,7 @@ using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
+using Tensorflow.Eager;
namespace TensorFlowNET.UnitTest.Eager
{
@@ -13,16 +14,67 @@ namespace TensorFlowNET.UnitTest.Eager
public class CApiVariableTest : CApiTest, IDisposable
{
Status status = new Status();
+ ContextOptions opts = new ContextOptions();
+ Context ctx;
[TestMethod]
public void Variables()
{
+ ctx = new Context(opts, status);
+ ASSERT_EQ(TF_Code.TF_OK, status.Code);
+ opts.Dispose();
+ var var_handle = CreateVariable(ctx, 12.0F);
+ ASSERT_EQ(TF_OK, TF_GetCode(status));
+ }
+
+ private IntPtr CreateVariable(Context ctx, float value)
+ {
+ // Create the variable handle.
+ var op = c_api.TFE_NewOp(ctx, "VarHandleOp", status);
+ if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
+
+ c_api.TFE_OpSetAttrType(op, "dtype", TF_DataType.TF_FLOAT);
+ c_api.TFE_OpSetAttrShape(op, "shape", new long[0], 0, status);
+ c_api.TFE_OpSetAttrString(op, "container", "", 0);
+ c_api.TFE_OpSetAttrString(op, "shared_name", "", 0);
+ if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
+ var var_handle = IntPtr.Zero;
+ int[] num_retvals = { 1 };
+ c_api.TFE_Execute(op, var_handle, num_retvals, status);
+ c_api.TFE_DeleteOp(op);
+ if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
+ ASSERT_EQ(1, num_retvals);
+
+ // Assign 'value' to it.
+ op = c_api.TFE_NewOp(ctx, "AssignVariableOp", status);
+ if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
+ c_api.TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+ c_api.TFE_OpAddInput(op, var_handle, status);
+
+ // Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
+ var t = new Tensor(value);
+
+ var value_handle = c_api.TFE_NewTensorHandle(t);
+ if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
+
+ c_api.TFE_OpAddInput(op, value_handle, status);
+ if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
+
+ num_retvals = new int[] { 0 };
+ c_api.TFE_Execute(op, IntPtr.Zero, num_retvals, status);
+ c_api.TFE_DeleteOp(op);
+ if (TF_GetCode(status) != TF_OK) return IntPtr.Zero;
+ ASSERT_EQ(0, num_retvals);
+
+ return var_handle;
}
public void Dispose()
{
-
+ status.Dispose();
+ opts.Dispose();
+ ctx.Dispose();
}
}
}