diff --git a/src/TensorFlowNET.Core/Eager/Context.cs b/src/TensorFlowNET.Core/Eager/Context.cs index 3d9c875d..f32790d8 100644 --- a/src/TensorFlowNET.Core/Eager/Context.cs +++ b/src/TensorFlowNET.Core/Eager/Context.cs @@ -4,13 +4,29 @@ using System.Text; namespace Tensorflow.Eager { - public class Context + public class Context : IDisposable { + private IntPtr _handle; + public static int GRAPH_MODE = 0; public static int EAGER_MODE = 1; public int default_execution_mode; + public Context(ContextOptions opts, Status status) + { + _handle = c_api.TFE_NewContext(opts, status); + status.Check(true); + } + + public void Dispose() + { + c_api.TFE_DeleteContext(_handle); + } + public static implicit operator IntPtr(Context ctx) + { + return ctx._handle; + } } } diff --git a/src/TensorFlowNET.Core/Eager/ContextOptions.cs b/src/TensorFlowNET.Core/Eager/ContextOptions.cs index 8bc49c8d..46d40bf6 100644 --- a/src/TensorFlowNET.Core/Eager/ContextOptions.cs +++ b/src/TensorFlowNET.Core/Eager/ContextOptions.cs @@ -18,9 +18,9 @@ namespace Tensorflow.Eager c_api.TFE_DeleteContextOptions(_handle); } - public static implicit operator IntPtr(ContextOptions ctx) + public static implicit operator IntPtr(ContextOptions opts) { - return ctx._handle; + return opts._handle; } } } diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index b8a88cf5..a6d9c8cd 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -10,24 +10,106 @@ namespace Tensorflow /// /// Return a new options object. /// - /// + /// TFE_ContextOptions* [DllImport(TensorFlowLibName)] public static extern IntPtr TFE_NewContextOptions(); /// /// Destroy an options object. /// - /// + /// TFE_ContextOptions* [DllImport(TensorFlowLibName)] public static extern void TFE_DeleteContextOptions(IntPtr options); + /// + /// + /// + /// const TFE_ContextOptions* + /// TF_Status* + /// TFE_Context* [DllImport(TensorFlowLibName)] public static extern IntPtr TFE_NewContext(IntPtr opts, IntPtr status); + /// + /// + /// + /// TFE_Context* [DllImport(TensorFlowLibName)] public static extern void TFE_DeleteContext(IntPtr ctx); + /// + /// Execute the operation defined by 'op' and return handles to computed + /// tensors in `retvals`. + /// + /// TFE_Op* + /// TFE_TensorHandle** + /// int* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TFE_Execute(IntPtr op, IntPtr retvals, int[] num_retvals, IntPtr status); + + /// + /// + /// + /// TFE_Context* + /// const char* + /// TF_Status* + /// [DllImport(TensorFlowLibName)] public static extern IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status); + + /// + /// + /// + /// TFE_Op* + [DllImport(TensorFlowLibName)] + public static extern void TFE_DeleteOp(IntPtr op); + + /// + /// + /// + /// TFE_Op* + /// const char* + /// TF_DataType + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value); + + /// + /// + /// + /// TFE_Op* + /// const char* + /// const int64_t* + /// const int + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, Status out_status); + + /// + /// + /// + /// TFE_Op* + /// const char* + /// const void* + /// size_t + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length); + + /// + /// + /// + /// TFE_Op* + /// TFE_TensorHandle* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TFE_OpAddInput(IntPtr op, IntPtr h, IntPtr status); + + /// + /// + /// + /// const tensorflow::Tensor& + /// TFE_TensorHandle* + [DllImport(TensorFlowLibName)] + public static extern IntPtr TFE_NewTensorHandle(IntPtr t); } } diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index 26b5e114..403d8897 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Eager; namespace Tensorflow { @@ -20,7 +21,7 @@ namespace Tensorflow _name = name; _default_name = default_name; _values = values; - _ctx = new Context(); + // _ctx = new Context(); } public string __enter__() diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 3e21d929..5d63a411 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Eager; namespace Tensorflow { @@ -12,7 +13,7 @@ namespace Tensorflow public static TF_DataType float64 = TF_DataType.TF_DOUBLE; public static TF_DataType chars = TF_DataType.TF_STRING; - public static Context context = new Context(); + public static Context context; public static Graph g = new Graph(c_api.TF_NewGraph()); @@ -28,6 +29,7 @@ namespace Tensorflow public static void enable_eager_execution() { + // contex = new Context(); context.default_execution_mode = Context.EAGER_MODE; } diff --git a/test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs b/test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs index 12830db6..1a5cb1a5 100644 --- a/test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs +++ b/test/TensorFlowNET.UnitTest/Eager/CApiVariableTest.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Text; using Tensorflow; +using Tensorflow.Eager; namespace TensorFlowNET.UnitTest.Eager { @@ -13,16 +14,67 @@ namespace TensorFlowNET.UnitTest.Eager public class CApiVariableTest : CApiTest, IDisposable { Status status = new Status(); + ContextOptions opts = new ContextOptions(); + Context ctx; [TestMethod] public void Variables() { + ctx = new Context(opts, status); + ASSERT_EQ(TF_Code.TF_OK, status.Code); + opts.Dispose(); + var var_handle = CreateVariable(ctx, 12.0F); + ASSERT_EQ(TF_OK, TF_GetCode(status)); + } + + private IntPtr CreateVariable(Context ctx, float value) + { + // Create the variable handle. + var op = c_api.TFE_NewOp(ctx, "VarHandleOp", status); + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + + c_api.TFE_OpSetAttrType(op, "dtype", TF_DataType.TF_FLOAT); + c_api.TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); + c_api.TFE_OpSetAttrString(op, "container", "", 0); + c_api.TFE_OpSetAttrString(op, "shared_name", "", 0); + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + var var_handle = IntPtr.Zero; + int[] num_retvals = { 1 }; + c_api.TFE_Execute(op, var_handle, num_retvals, status); + c_api.TFE_DeleteOp(op); + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + ASSERT_EQ(1, num_retvals); + + // Assign 'value' to it. + op = c_api.TFE_NewOp(ctx, "AssignVariableOp", status); + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + c_api.TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + c_api.TFE_OpAddInput(op, var_handle, status); + + // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. + var t = new Tensor(value); + + var value_handle = c_api.TFE_NewTensorHandle(t); + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + + c_api.TFE_OpAddInput(op, value_handle, status); + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + + num_retvals = new int[] { 0 }; + c_api.TFE_Execute(op, IntPtr.Zero, num_retvals, status); + c_api.TFE_DeleteOp(op); + if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; + ASSERT_EQ(0, num_retvals); + + return var_handle; } public void Dispose() { - + status.Dispose(); + opts.Dispose(); + ctx.Dispose(); } } }