| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -14,16 +15,16 @@ namespace Tensorflow | |||||
| private Status status = new Status(); | private Status status = new Status(); | ||||
| public string name { get; } | |||||
| public string optype { get; } | |||||
| public string device { get; } | |||||
| public int NumOutputs { get; } | |||||
| public TF_DataType OutputType { get; } | |||||
| public int OutputListLength { get; } | |||||
| public int NumInputs { get; } | |||||
| public int NumConsumers { get; } | |||||
| public int NumControlInputs { get; } | |||||
| public int NumControlOutputs { get; } | |||||
| public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||||
| public string optype => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||||
| public string device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||||
| public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | |||||
| public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||||
| public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); | |||||
| public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||||
| public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||||
| public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||||
| public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||||
| private Tensor[] _outputs; | private Tensor[] _outputs; | ||||
| public Tensor[] outputs => _outputs; | public Tensor[] outputs => _outputs; | ||||
| @@ -35,17 +36,6 @@ namespace Tensorflow | |||||
| return; | return; | ||||
| _handle = handle; | _handle = handle; | ||||
| name = c_api.TF_OperationName(_handle); | |||||
| optype = c_api.TF_OperationOpType(_handle); | |||||
| device = "";// c_api.TF_OperationDevice(_handle); | |||||
| NumOutputs = c_api.TF_OperationNumOutputs(_handle); | |||||
| OutputType = c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||||
| OutputListLength = c_api.TF_OperationOutputListLength(_handle, "output", status); | |||||
| NumInputs = c_api.TF_OperationNumInputs(_handle); | |||||
| NumConsumers = c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||||
| NumControlInputs = c_api.TF_OperationNumControlInputs(_handle); | |||||
| NumControlOutputs = c_api.TF_OperationNumControlOutputs(_handle); | |||||
| } | } | ||||
| public Operation(Graph g, string opType, string oper_name) | public Operation(Graph g, string opType, string oper_name) | ||||
| @@ -62,8 +52,8 @@ namespace Tensorflow | |||||
| Graph = g; | Graph = g; | ||||
| _id_value = Graph._next_id(); | _id_value = Graph._next_id(); | ||||
| _handle = ops._create_c_op(g, node_def, inputs); | _handle = ops._create_c_op(g, node_def, inputs); | ||||
| NumOutputs = c_api.TF_OperationNumOutputs(_handle); | |||||
| _outputs = new Tensor[NumOutputs]; | _outputs = new Tensor[NumOutputs]; | ||||
| for (int i = 0; i < NumOutputs; i++) | for (int i = 0; i < NumOutputs; i++) | ||||
| @@ -38,7 +38,7 @@ namespace Tensorflow | |||||
| public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern string TF_OperationDevice(IntPtr oper); | |||||
| public static extern IntPtr TF_OperationDevice(IntPtr oper); | |||||
| /// <summary> | /// <summary> | ||||
| /// Sets `output_attr_value` to the binary-serialized AttrValue proto | /// Sets `output_attr_value` to the binary-serialized AttrValue proto | ||||
| @@ -50,13 +50,13 @@ namespace Tensorflow | |||||
| public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); | public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern string TF_OperationName(IntPtr oper); | |||||
| public static extern IntPtr TF_OperationName(IntPtr oper); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TF_OperationNumInputs(IntPtr oper); | public static extern int TF_OperationNumInputs(IntPtr oper); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern string TF_OperationOpType(IntPtr oper); | |||||
| public static extern IntPtr TF_OperationOpType(IntPtr oper); | |||||
| /// <summary> | /// <summary> | ||||
| /// Get the number of control inputs to an operation. | /// Get the number of control inputs to an operation. | ||||
| @@ -30,12 +30,12 @@ namespace Tensorflow | |||||
| // Add inputs | // Add inputs | ||||
| if(inputs != null && inputs.Count > 0) | if(inputs != null && inputs.Count > 0) | ||||
| { | { | ||||
| /*foreach (var op_input in inputs) | |||||
| foreach (var op_input in inputs) | |||||
| { | { | ||||
| c_api.TF_AddInput(op_desc, op_input._as_tf_output()); | c_api.TF_AddInput(op_desc, op_input._as_tf_output()); | ||||
| }*/ | |||||
| } | |||||
| c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); | |||||
| //c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); | |||||
| } | } | ||||
| var status = new Status(); | var status = new Status(); | ||||
| @@ -15,7 +15,7 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Error message | /// Error message | ||||
| /// </summary> | /// </summary> | ||||
| public string Message => c_api.TF_Message(_handle); | |||||
| public string Message => c_api.StringPiece(c_api.TF_Message(_handle)); | |||||
| /// <summary> | /// <summary> | ||||
| /// Error code | /// Error code | ||||
| @@ -12,7 +12,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="s"></param> | /// <param name="s"></param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern void TF_DeleteStatus(IntPtr s); | |||||
| public static extern void TF_DeleteStatus(IntPtr s); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return the code record in *s. | /// Return the code record in *s. | ||||
| @@ -20,7 +20,7 @@ namespace Tensorflow | |||||
| /// <param name="s"></param> | /// <param name="s"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe TF_Code TF_GetCode(IntPtr s); | |||||
| public static extern TF_Code TF_GetCode(IntPtr s); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return a pointer to the (null-terminated) error message in *s. | /// Return a pointer to the (null-terminated) error message in *s. | ||||
| @@ -30,7 +30,7 @@ namespace Tensorflow | |||||
| /// <param name="s"></param> | /// <param name="s"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe string TF_Message(IntPtr s); | |||||
| public static extern IntPtr TF_Message(IntPtr s); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return a new status object. | /// Return a new status object. | ||||
| @@ -12,20 +12,27 @@ namespace Tensorflow | |||||
| /// The API leans towards simplicity and uniformity instead of convenience | /// The API leans towards simplicity and uniformity instead of convenience | ||||
| /// since most usage will be by language specific wrappers. | /// since most usage will be by language specific wrappers. | ||||
| /// | /// | ||||
| /// The params type mapping between .net and c_api | |||||
| /// The params type mapping between c_api and .NET | |||||
| /// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op) | /// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op) | ||||
| /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | ||||
| /// struct => struct (TF_Output output) => (TF_Output output) | /// struct => struct (TF_Output output) => (TF_Output output) | ||||
| /// struct* => struct (TF_Output* output) => (TF_Output[] output) | |||||
| /// const char* => string | /// const char* => string | ||||
| /// int32_t => int | /// int32_t => int | ||||
| /// int64_t* => long[] | /// int64_t* => long[] | ||||
| /// size_t* => unlong[] | /// size_t* => unlong[] | ||||
| /// void* => IntPtr | /// void* => IntPtr | ||||
| /// string => IntPtr c_api.StringPiece(IntPtr) | |||||
| /// </summary> | /// </summary> | ||||
| public static partial class c_api | public static partial class c_api | ||||
| { | { | ||||
| public const string TensorFlowLibName = "tensorflow"; | public const string TensorFlowLibName = "tensorflow"; | ||||
| public static string StringPiece(IntPtr handle) | |||||
| { | |||||
| return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | |||||
| } | |||||
| public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocator); | public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocator); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| @@ -37,7 +37,7 @@ namespace Tensorflow | |||||
| context.default_execution_mode = Context.EAGER_MODE; | context.default_execution_mode = Context.EAGER_MODE; | ||||
| } | } | ||||
| public static string VERSION => Marshal.PtrToStringAnsi(c_api.TF_Version()); | |||||
| public static string VERSION => c_api.StringPiece(c_api.TF_Version()); | |||||
| public static Graph get_default_graph() | public static Graph get_default_graph() | ||||
| { | { | ||||
| @@ -39,17 +39,14 @@ namespace TensorFlowNET.UnitTest | |||||
| // Test not found errors in TF_Operation*() query functions. | // Test not found errors in TF_Operation*() query functions. | ||||
| Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | ||||
| Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); | Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); | ||||
| //Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||||
| //Assert.AreEqual("Operation '' has no attr named 'missing'.", s.Message); | |||||
| Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||||
| Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); | |||||
| // Make a constant oper with the scalar "3". | // Make a constant oper with the scalar "3". | ||||
| var three = c_test_util.ScalarConst(3, graph, s); | var three = c_test_util.ScalarConst(3, graph, s); | ||||
| // Add oper. | // Add oper. | ||||
| var add = c_test_util.Add(feed, three, graph, s); | var add = c_test_util.Add(feed, three, graph, s); | ||||
| NodeDef node_def = null; | |||||
| c_test_util.GetNodeDef(feed, ref node_def); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -43,7 +43,7 @@ namespace TensorFlowNET.UnitTest | |||||
| public void addInConstant() | public void addInConstant() | ||||
| { | { | ||||
| var a = tf.constant(4.0f); | var a = tf.constant(4.0f); | ||||
| var b = tf.placeholder(tf.float32); | |||||
| var b = tf.constant(5.0f); | |||||
| var c = tf.add(a, b); | var c = tf.add(a, b); | ||||
| using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
| @@ -23,11 +23,13 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| var desc = c_api.TF_NewOperation(graph, "AddN", name); | var desc = c_api.TF_NewOperation(graph, "AddN", name); | ||||
| c_api.TF_AddInputList(desc, new TF_Output[] | |||||
| var inputs = new TF_Output[] | |||||
| { | { | ||||
| new TF_Output(l, 0), | new TF_Output(l, 0), | ||||
| new TF_Output(r, 0), | new TF_Output(r, 0), | ||||
| }, 2); | |||||
| }; | |||||
| c_api.TF_AddInputList(desc, inputs, inputs.Length); | |||||
| op = c_api.TF_FinishOperation(desc, s); | op = c_api.TF_FinishOperation(desc, s); | ||||
| s.Check(); | s.Check(); | ||||