Browse Source

added a lot of APIs related to Operation.

tags/v0.1.0-Tensor
Oceania2018 7 years ago
parent
commit
a23be5a673
10 changed files with 174 additions and 45 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  2. +7
    -0
      src/TensorFlowNET.Core/Buffers/c_api.buffer.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  4. +34
    -16
      src/TensorFlowNET.Core/Operations/Operation.cs
  5. +6
    -0
      src/TensorFlowNET.Core/Operations/TF_Output.cs
  6. +62
    -8
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  7. +11
    -11
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  8. +19
    -2
      test/TensorFlowNET.UnitTest/GraphTest.cs
  9. +1
    -1
      test/TensorFlowNET.UnitTest/TensorTest.cs
  10. +31
    -4
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 2
- 2
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -8,7 +8,6 @@ namespace Tensorflow
public class Buffer public class Buffer
{ {
private IntPtr _handle; private IntPtr _handle;
public IntPtr Handle => _handle;


private TF_Buffer buffer; private TF_Buffer buffer;


@@ -21,7 +20,8 @@ namespace Tensorflow
_handle = handle; _handle = handle;
buffer = Marshal.PtrToStructure<TF_Buffer>(_handle); buffer = Marshal.PtrToStructure<TF_Buffer>(_handle);
Data = new byte[buffer.length]; Data = new byte[buffer.length];
Marshal.Copy(buffer.data, Data, 0, (int)buffer.length);
if (buffer.length > 0)
Marshal.Copy(buffer.data, Data, 0, (int)buffer.length);
} }
} }
} }

+ 7
- 0
src/TensorFlowNET.Core/Buffers/c_api.buffer.cs View File

@@ -7,6 +7,13 @@ namespace Tensorflow
{ {
public static partial class c_api public static partial class c_api
{ {
/// <summary>
/// Useful for passing *out* a protobuf.
/// </summary>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewBuffer();

[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GetBuffer(TF_Buffer buffer); public static extern IntPtr TF_GetBuffer(TF_Buffer buffer);
} }


+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -64,7 +64,7 @@ namespace Tensorflow


if(obj is Tensor && allow_tensor) if(obj is Tensor && allow_tensor)
{ {
if ((obj as Tensor).graph.Equals(this))
if ((obj as Tensor).Graph.Equals(this))
{ {
return obj; return obj;
} }


+ 34
- 16
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -6,28 +6,37 @@ namespace Tensorflow
{ {
public class Operation public class Operation
{ {
public IntPtr Handle { get; }
private readonly IntPtr _handle;


private Graph _graph;
public Graph graph => _graph;
public IntPtr _c_op;
public Graph Graph { get; }
public int _id => _id_value; public int _id => _id_value;
private int _id_value; private int _id_value;
public string name;

private Status status = new Status();

public string name => c_api.TF_OperationName(_handle);
public string optype => c_api.TF_OperationOpType(_handle);
public string device => c_api.TF_OperationDevice(_handle);
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0));
public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status);
public int NumInputs => c_api.TF_OperationNumInputs(_handle);
public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0));
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);

private Tensor[] _outputs; private Tensor[] _outputs;
public Tensor[] outputs => _outputs; public Tensor[] outputs => _outputs;
public Tensor[] inputs; public Tensor[] inputs;


public Operation(IntPtr handle) public Operation(IntPtr handle)
{ {
Handle = handle;
_handle = handle;
} }


public Operation(Graph g, string opType, string oper_name) public Operation(Graph g, string opType, string oper_name)
{ {
_graph = g;

var status = new Status();
Graph = g;


var desc = c_api.TF_NewOperation(g, opType, oper_name); var desc = c_api.TF_NewOperation(g, opType, oper_name);
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);
@@ -36,19 +45,18 @@ namespace Tensorflow


public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{ {
_graph = g;
Graph = g;


_id_value = _graph._next_id();
_c_op = ops._create_c_op(g, node_def, inputs);
var num_outputs = c_api.TF_OperationNumOutputs(_c_op);
_id_value = Graph._next_id();
_handle = ops._create_c_op(g, node_def, inputs);


_outputs = new Tensor[num_outputs];
for (int i = 0; i < num_outputs; i++)
_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
{ {
_outputs[i] = new Tensor(this, i, output_types[i]); _outputs[i] = new Tensor(this, i, output_types[i]);
} }


_graph._add_op(this);
Graph._add_op(this);
} }


public object get_attr(string name) public object get_attr(string name)
@@ -69,5 +77,15 @@ namespace Tensorflow


return ret; return ret;
} }

public static implicit operator Operation(IntPtr handle)
{
return new Operation(handle);
}

public static implicit operator IntPtr(Operation op)
{
return op._handle;
}
} }
} }

+ 6
- 0
src/TensorFlowNET.Core/Operations/TF_Output.cs View File

@@ -8,6 +8,12 @@ namespace Tensorflow
[StructLayout(LayoutKind.Sequential)] [StructLayout(LayoutKind.Sequential)]
public struct TF_Output public struct TF_Output
{ {
public TF_Output(IntPtr oper, int index)
{
this.oper = oper;
this.index = index;
}

public IntPtr oper; public IntPtr oper;
public int index; public int index;
} }


+ 62
- 8
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_GetAllOpList();
public static extern IntPtr TF_GetAllOpList();


/// <summary> /// <summary>
/// For inputs that take a single tensor. /// For inputs that take a single tensor.
@@ -20,24 +20,78 @@ namespace Tensorflow
/// <param name="desc"></param> /// <param name="desc"></param>
/// <param name="input"></param> /// <param name="input"></param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static unsafe extern void TF_AddInput(IntPtr desc, TF_Output input);
public static extern void TF_AddInput(IntPtr desc, TF_Output input);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status);
public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name);
public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe int TF_OperationNumOutputs(IntPtr oper);
public static extern string TF_OperationDevice(IntPtr oper);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status);
public static extern string TF_OperationName(IntPtr oper);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status);
public static extern int TF_OperationNumInputs(IntPtr oper);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value);
public static extern string TF_OperationOpType(IntPtr oper);

/// <summary>
/// Get the number of control inputs to an operation.
/// </summary>
/// <param name="oper"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationNumControlInputs(IntPtr oper);

/// <summary>
/// Get the number of operations that have `*oper` as a control input.
/// </summary>
/// <param name="oper"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationNumControlOutputs(IntPtr oper);

[DllImport(TensorFlowLibName)]
public static extern int TF_OperationNumOutputs(IntPtr oper);

/// <summary>
/// Get the number of current consumers of a specific output of an
/// operation. Note that this number can change when new operations
/// are added to the graph.
/// </summary>
/// <param name="oper_out"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputNumConsumers(TF_Output oper_out);

[DllImport(TensorFlowLibName)]
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out);

[DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status);

/// <summary>
/// Set `num_dims` to -1 to represent "unknown rank".
/// </summary>
/// <param name="desc"></param>
/// <param name="attr_name"></param>
/// <param name="dims"></param>
/// <param name="num_dims"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrShape(IntPtr desc, string attr_name, long[] dims, int num_dims);

[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value);
} }
} }

+ 11
- 11
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -13,9 +13,9 @@ namespace Tensorflow
/// </summary> /// </summary>
public class Tensor public class Tensor
{ {
public IntPtr Handle { get; }
private readonly IntPtr _handle;


public Graph graph => op.graph;
public Graph Graph => op.Graph;
public Operation op { get; } public Operation op { get; }


public string name; public string name;
@@ -46,7 +46,7 @@ namespace Tensorflow


public Tensor(IntPtr handle) public Tensor(IntPtr handle)
{ {
Handle = handle;
_handle = handle;
dtype = c_api.TF_TensorType(handle); dtype = c_api.TF_TensorType(handle);
rank = c_api.TF_NumDims(handle); rank = c_api.TF_NumDims(handle);
bytesize = c_api.TF_TensorByteSize(handle); bytesize = c_api.TF_TensorByteSize(handle);
@@ -60,16 +60,16 @@ namespace Tensorflow


public Tensor(NDArray nd) public Tensor(NDArray nd)
{ {
Handle = Allocate(nd);
dtype = c_api.TF_TensorType(Handle);
rank = c_api.TF_NumDims(Handle);
bytesize = c_api.TF_TensorByteSize(Handle);
buffer = c_api.TF_TensorData(Handle);
_handle = Allocate(nd);
dtype = c_api.TF_TensorType(_handle);
rank = c_api.TF_NumDims(_handle);
bytesize = c_api.TF_TensorByteSize(_handle);
buffer = c_api.TF_TensorData(_handle);
dataTypeSize = c_api.TF_DataTypeSize(dtype); dataTypeSize = c_api.TF_DataTypeSize(dtype);


shape = new long[rank]; shape = new long[rank];
for (int i = 0; i < rank; i++) for (int i = 0; i < rank; i++)
shape[i] = c_api.TF_Dim(Handle, i);
shape[i] = c_api.TF_Dim(_handle, i);
} }


private IntPtr Allocate(NDArray nd) private IntPtr Allocate(NDArray nd)
@@ -117,7 +117,7 @@ namespace Tensorflow


public TF_Output _as_tf_output() public TF_Output _as_tf_output()
{ {
return c_api_util.tf_output(op._c_op, value_index);
return c_api_util.tf_output(op, value_index);
} }


public T[] Data<T>() public T[] Data<T>()
@@ -162,7 +162,7 @@ namespace Tensorflow


public static implicit operator IntPtr(Tensor tensor) public static implicit operator IntPtr(Tensor tensor)
{ {
return tensor.Handle;
return tensor._handle;
} }
} }
} }

+ 19
- 2
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -10,9 +10,26 @@ namespace TensorFlowNET.UnitTest
public class GraphTest public class GraphTest
{ {
[TestMethod] [TestMethod]
public void ConstructGraph()
public void Graph()
{ {
var g = tf.Graph();
var s = new Status();
var graph = tf.get_default_graph();

// Make a placeholder operation.
var feed = c_test_util.Placeholder(graph, s);
Assert.AreEqual("feed", feed.name);
Assert.AreEqual("Placeholder", feed.optype);
//Assert.AreEqual("", feed.device);
Assert.AreEqual(1, feed.NumOutputs);
Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType);
Assert.AreEqual(1, feed.OutputListLength);
Assert.AreEqual(0, feed.NumInputs);
Assert.AreEqual(0, feed.NumConsumers);
Assert.AreEqual(0, feed.NumControlInputs);
Assert.AreEqual(0, feed.NumControlOutputs);

var attr_value = new AttrValue();
c_test_util.GetAttrValue(feed, "dtype", attr_value, s);
} }
} }
} }

+ 1
- 1
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -87,7 +87,7 @@ namespace TensorFlowNET.UnitTest
// Test for a scalar. // Test for a scalar.
var three = c_test_util.ScalarConst(3, graph, s); var three = c_test_util.ScalarConst(3, graph, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
var three_out_0 = new TF_Output { oper = three.Handle };
var three_out_0 = new TF_Output { oper = three };
num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s);
Assert.AreEqual(0, num_dims); Assert.AreEqual(0, num_dims);
} }


+ 31
- 4
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -3,15 +3,42 @@ using System.Collections.Generic;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
using Tensorflow; using Tensorflow;
using Buffer = Tensorflow.Buffer;


namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
{ {
public static class c_test_util public static class c_test_util
{ {
public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref IntPtr op)
public static bool GetAttrValue(Operation oper, string attr_name, AttrValue attr_value, Status s)
{
var buffer = c_api.TF_NewBuffer();

return s.Code == TF_Code.TF_OK;
}

public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op)
{
var desc = c_api.TF_NewOperation(graph, "Placeholder", name);
c_api.TF_SetAttrType(desc, "dtype", dtype);
if(dims != null)
{
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length);
}
op = c_api.TF_FinishOperation(desc, s);
s.Check();
}

public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null)
{
Operation op = null;
PlaceholderHelper(graph, s, name, dtype, dims, ref op);
return op;
}

public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref Operation op)
{ {
var desc = c_api.TF_NewOperation(graph, "Const", name); var desc = c_api.TF_NewOperation(graph, "Const", name);
c_api.TF_SetAttrTensor(desc, "value", t.Handle, s);
c_api.TF_SetAttrTensor(desc, "value", t, s);
s.Check(); s.Check();
c_api.TF_SetAttrType(desc, "dtype", t.dtype); c_api.TF_SetAttrType(desc, "dtype", t.dtype);
op = c_api.TF_FinishOperation(desc, s); op = c_api.TF_FinishOperation(desc, s);
@@ -24,9 +51,9 @@ namespace TensorFlowNET.UnitTest


public static Operation Const(Tensor t, Graph graph, Status s, string name) public static Operation Const(Tensor t, Graph graph, Status s, string name)
{ {
IntPtr op = IntPtr.Zero;
Operation op = null;
ConstHelper(t, graph, s, name, ref op); ConstHelper(t, graph, s, name, ref op);
return new Operation(op);
return op;
} }


public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const") public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const")


Loading…
Cancel
Save