| @@ -4,13 +4,29 @@ using System.Text; | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| public class Context | |||||
| public class Context : IDisposable | |||||
| { | { | ||||
| private IntPtr _handle; | |||||
| public static int GRAPH_MODE = 0; | public static int GRAPH_MODE = 0; | ||||
| public static int EAGER_MODE = 1; | public static int EAGER_MODE = 1; | ||||
| public int default_execution_mode; | public int default_execution_mode; | ||||
| public Context(ContextOptions opts, Status status) | |||||
| { | |||||
| _handle = c_api.TFE_NewContext(opts, status); | |||||
| status.Check(true); | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| c_api.TFE_DeleteContext(_handle); | |||||
| } | |||||
| public static implicit operator IntPtr(Context ctx) | |||||
| { | |||||
| return ctx._handle; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -18,9 +18,9 @@ namespace Tensorflow.Eager | |||||
| c_api.TFE_DeleteContextOptions(_handle); | c_api.TFE_DeleteContextOptions(_handle); | ||||
| } | } | ||||
| public static implicit operator IntPtr(ContextOptions ctx) | |||||
| public static implicit operator IntPtr(ContextOptions opts) | |||||
| { | { | ||||
| return ctx._handle; | |||||
| return opts._handle; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -10,24 +10,106 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Return a new options object. | /// Return a new options object. | ||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | |||||
| /// <returns>TFE_ContextOptions*</returns> | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_NewContextOptions(); | public static extern IntPtr TFE_NewContextOptions(); | ||||
| /// <summary> | /// <summary> | ||||
| /// Destroy an options object. | /// Destroy an options object. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="options"></param> | |||||
| /// <param name="options">TFE_ContextOptions*</param> | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_DeleteContextOptions(IntPtr options); | public static extern void TFE_DeleteContextOptions(IntPtr options); | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="opts">const TFE_ContextOptions*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| /// <returns>TFE_Context*</returns> | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_NewContext(IntPtr opts, IntPtr status); | public static extern IntPtr TFE_NewContext(IntPtr opts, IntPtr status); | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="ctx">TFE_Context*</param> | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_DeleteContext(IntPtr ctx); | public static extern void TFE_DeleteContext(IntPtr ctx); | ||||
| /// <summary> | |||||
| /// Execute the operation defined by 'op' and return handles to computed | |||||
| /// tensors in `retvals`. | |||||
| /// </summary> | |||||
| /// <param name="op">TFE_Op*</param> | |||||
| /// <param name="retvals">TFE_TensorHandle**</param> | |||||
| /// <param name="num_retvals">int*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_Execute(IntPtr op, IntPtr retvals, int[] num_retvals, IntPtr status); | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="ctx">TFE_Context*</param> | |||||
| /// <param name="op_or_function_name">const char*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status); | public static extern IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status); | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="op">TFE_Op*</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_DeleteOp(IntPtr op); | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="op">TFE_Op*</param> | |||||
| /// <param name="attr_name">const char*</param> | |||||
| /// <param name="value">TF_DataType</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value); | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="op">TFE_Op*</param> | |||||
| /// <param name="attr_name">const char*</param> | |||||
| /// <param name="dims">const int64_t*</param> | |||||
| /// <param name="num_dims">const int</param> | |||||
| /// <param name="out_status">TF_Status*</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, Status out_status); | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="op">TFE_Op*</param> | |||||
| /// <param name="attr_name">const char*</param> | |||||
| /// <param name="value">const void*</param> | |||||
| /// <param name="length">size_t</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length); | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="op">TFE_Op*</param> | |||||
| /// <param name="h">TFE_TensorHandle*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_OpAddInput(IntPtr op, IntPtr h, IntPtr status); | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="t">const tensorflow::Tensor&</param> | |||||
| /// <returns>TFE_TensorHandle*</returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TFE_NewTensorHandle(IntPtr t); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -20,7 +21,7 @@ namespace Tensorflow | |||||
| _name = name; | _name = name; | ||||
| _default_name = default_name; | _default_name = default_name; | ||||
| _values = values; | _values = values; | ||||
| _ctx = new Context(); | |||||
| // _ctx = new Context(); | |||||
| } | } | ||||
| public string __enter__() | public string __enter__() | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -12,7 +13,7 @@ namespace Tensorflow | |||||
| public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | ||||
| public static TF_DataType chars = TF_DataType.TF_STRING; | public static TF_DataType chars = TF_DataType.TF_STRING; | ||||
| public static Context context = new Context(); | |||||
| public static Context context; | |||||
| public static Graph g = new Graph(c_api.TF_NewGraph()); | public static Graph g = new Graph(c_api.TF_NewGraph()); | ||||
| @@ -28,6 +29,7 @@ namespace Tensorflow | |||||
| public static void enable_eager_execution() | public static void enable_eager_execution() | ||||
| { | { | ||||
| // contex = new Context(); | |||||
| context.default_execution_mode = Context.EAGER_MODE; | context.default_execution_mode = Context.EAGER_MODE; | ||||
| } | } | ||||
| @@ -3,6 +3,7 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | |||||
| namespace TensorFlowNET.UnitTest.Eager | namespace TensorFlowNET.UnitTest.Eager | ||||
| { | { | ||||
| @@ -13,16 +14,67 @@ namespace TensorFlowNET.UnitTest.Eager | |||||
| public class CApiVariableTest : CApiTest, IDisposable | public class CApiVariableTest : CApiTest, IDisposable | ||||
| { | { | ||||
| Status status = new Status(); | Status status = new Status(); | ||||
| ContextOptions opts = new ContextOptions(); | |||||
| Context ctx; | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Variables() | public void Variables() | ||||
| { | { | ||||
| ctx = new Context(opts, status); | |||||
| ASSERT_EQ(TF_Code.TF_OK, status.Code); | |||||
| opts.Dispose(); | |||||
| var var_handle = CreateVariable(ctx, 12.0F); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status)); | |||||
| } | |||||
| private IntPtr CreateVariable(Context ctx, float value) | |||||
| { | |||||
| // Create the variable handle. | |||||
| var op = c_api.TFE_NewOp(ctx, "VarHandleOp", status); | |||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| c_api.TFE_OpSetAttrType(op, "dtype", TF_DataType.TF_FLOAT); | |||||
| c_api.TFE_OpSetAttrShape(op, "shape", new long[0], 0, status); | |||||
| c_api.TFE_OpSetAttrString(op, "container", "", 0); | |||||
| c_api.TFE_OpSetAttrString(op, "shared_name", "", 0); | |||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| var var_handle = IntPtr.Zero; | |||||
| int[] num_retvals = { 1 }; | |||||
| c_api.TFE_Execute(op, var_handle, num_retvals, status); | |||||
| c_api.TFE_DeleteOp(op); | |||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| ASSERT_EQ(1, num_retvals); | |||||
| // Assign 'value' to it. | |||||
| op = c_api.TFE_NewOp(ctx, "AssignVariableOp", status); | |||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| c_api.TFE_OpSetAttrType(op, "dtype", TF_FLOAT); | |||||
| c_api.TFE_OpAddInput(op, var_handle, status); | |||||
| // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. | |||||
| var t = new Tensor(value); | |||||
| var value_handle = c_api.TFE_NewTensorHandle(t); | |||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| c_api.TFE_OpAddInput(op, value_handle, status); | |||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| num_retvals = new int[] { 0 }; | |||||
| c_api.TFE_Execute(op, IntPtr.Zero, num_retvals, status); | |||||
| c_api.TFE_DeleteOp(op); | |||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | |||||
| ASSERT_EQ(0, num_retvals); | |||||
| return var_handle; | |||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| status.Dispose(); | |||||
| opts.Dispose(); | |||||
| ctx.Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||