Browse Source

add implicit for Graph, Operation, Tensor, Status.

tags/v0.1.0-Tensor
Oceania2018 7 years ago
parent
commit
9e42e3c67f
20 changed files with 282 additions and 127 deletions
  1. +9
    -11
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Buffers/TF_Buffer.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Buffers/c_api.buffer.cs
  4. +7
    -3
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +33
    -0
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  6. +0
    -3
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  7. +9
    -3
      src/TensorFlowNET.Core/Operations/Operation.cs
  8. +0
    -0
      src/TensorFlowNET.Core/Operations/TF_Input.cs
  9. +0
    -0
      src/TensorFlowNET.Core/Operations/TF_Output.cs
  10. +13
    -6
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  11. +3
    -3
      src/TensorFlowNET.Core/Operations/ops.cs
  12. +10
    -10
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  13. +0
    -59
      src/TensorFlowNET.Core/Sessions/FeedDict.cs
  14. +1
    -1
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  15. +22
    -2
      src/TensorFlowNET.Core/Status/Status.cs
  16. +48
    -19
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  17. +11
    -1
      src/TensorFlowNET.Core/c_api.cs
  18. +12
    -3
      test/TensorFlowNET.UnitTest/OperationsTest.cs
  19. +65
    -1
      test/TensorFlowNET.UnitTest/TensorTest.cs
  20. +37
    -0
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 9
- 11
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -9,21 +9,19 @@ namespace Tensorflow
{ {
private IntPtr _handle; private IntPtr _handle;
public IntPtr Handle => _handle; public IntPtr Handle => _handle;
//public TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle);


public unsafe Buffer()
{
_handle = Marshal.AllocHGlobal(sizeof(TF_Buffer));
}
private TF_Buffer buffer;


public byte[] GetBuffer()
{
var buffer = Marshal.PtrToStructure<TF_Buffer>(_handle);
public byte[] Data;


var data = Marshal.AllocHGlobal(buffer.length);
//var bytes = c_api.TF_GetBuffer(buffer.data);
public int Length => (int)buffer.length;


return null;
public unsafe Buffer(IntPtr handle)
{
_handle = handle;
buffer = Marshal.PtrToStructure<TF_Buffer>(_handle);
Data = new byte[buffer.length];
Marshal.Copy(buffer.data, Data, 0, (int)buffer.length);
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Buffers/TF_Buffer.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow
public struct TF_Buffer public struct TF_Buffer
{ {
public IntPtr data; public IntPtr data;
public int length;
public ulong length;
public IntPtr data_deallocator; public IntPtr data_deallocator;
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Buffers/c_api.buffer.cs View File

@@ -8,6 +8,6 @@ namespace Tensorflow
public static partial class c_api public static partial class c_api
{ {
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern string TF_GetBuffer(IntPtr buffer);
public static extern IntPtr TF_GetBuffer(TF_Buffer buffer);
} }
} }

+ 7
- 3
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -15,8 +15,7 @@ namespace Tensorflow
/// </summary> /// </summary>
public class Graph public class Graph
{ {
private IntPtr _c_graph;
public IntPtr Handle => _c_graph;
private IntPtr _handle;
private Dictionary<int, Operation> _nodes_by_id; private Dictionary<int, Operation> _nodes_by_id;
private Dictionary<string, Operation> _nodes_by_name; private Dictionary<string, Operation> _nodes_by_name;
private Dictionary<string, int> _names_in_use; private Dictionary<string, int> _names_in_use;
@@ -28,7 +27,7 @@ namespace Tensorflow


public Graph(IntPtr graph) public Graph(IntPtr graph)
{ {
this._c_graph = graph;
_handle = graph;
_nodes_by_id = new Dictionary<int, Operation>(); _nodes_by_id = new Dictionary<int, Operation>();
_nodes_by_name = new Dictionary<string, Operation>(); _nodes_by_name = new Dictionary<string, Operation>();
_names_in_use = new Dictionary<string, int>(); _names_in_use = new Dictionary<string, int>();
@@ -171,5 +170,10 @@ namespace Tensorflow
{ {
return _nodes_by_name.Values.Select(x => x).ToArray(); return _nodes_by_name.Values.Select(x => x).ToArray();
} }

public static implicit operator IntPtr(Graph graph)
{
return graph._handle;
}
} }
} }

+ 33
- 0
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -10,6 +10,39 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status);


/// <summary>
/// Returns the shape of the Tensor referenced by `output` in `graph`
/// into `dims`. `dims` must be an array large enough to hold `num_dims`
/// entries (e.g., the return value of TF_GraphGetTensorNumDims).
/// </summary>
/// <param name="graph"></param>
/// <param name="output"></param>
/// <param name="dims"></param>
/// <param name="num_dims"></param>
/// <param name="status"></param>
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status);

/// <summary>
/// Sets the shape of the Tensor referenced by `output` in `graph` to
/// the shape described by `dims` and `num_dims`.
/// </summary>
[DllImport(TensorFlowLibName)]
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status);

/// <summary>
/// Returns the number of dimensions of the Tensor referenced by `output`
/// in `graph`.
///
/// If the number of dimensions in the shape is unknown, returns -1.
/// </summary>
/// <param name="graph"></param>
/// <param name="output"></param>
/// <param name="status"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status);

[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_NewGraph(); public static unsafe extern IntPtr TF_NewGraph();
} }


+ 0
- 3
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -28,9 +28,6 @@ namespace Tensorflow
{ {
var op_def = _ops[op_type_name]; var op_def = _ops[op_type_name];


var status = new Status();
var buffer = new Buffer();

var g = ops.get_default_graph(); var g = ops.get_default_graph();


if (String.IsNullOrEmpty(name)) if (String.IsNullOrEmpty(name))


+ 9
- 3
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -1,12 +1,13 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using TF_DataType = Tensorflow.DataType;


namespace Tensorflow namespace Tensorflow
{ {
public class Operation public class Operation
{ {
public IntPtr Handle { get; }

private Graph _graph; private Graph _graph;
public Graph graph => _graph; public Graph graph => _graph;
public IntPtr _c_op; public IntPtr _c_op;
@@ -17,15 +18,20 @@ namespace Tensorflow
public Tensor[] outputs => _outputs; public Tensor[] outputs => _outputs;
public Tensor[] inputs; public Tensor[] inputs;


public Operation(IntPtr handle)
{
Handle = handle;
}

public Operation(Graph g, string opType, string oper_name) public Operation(Graph g, string opType, string oper_name)
{ {
_graph = g; _graph = g;


var status = new Status(); var status = new Status();


var desc = c_api.TF_NewOperation(g.Handle, opType, oper_name);
var desc = c_api.TF_NewOperation(g, opType, oper_name);
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);
c_api.TF_FinishOperation(desc, status.Handle);
c_api.TF_FinishOperation(desc, status);
} }


public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)


src/TensorFlowNET.Core/Graphs/TF_Input.cs → src/TensorFlowNET.Core/Operations/TF_Input.cs View File


src/TensorFlowNET.Core/Graphs/TF_Output.cs → src/TensorFlowNET.Core/Operations/TF_Output.cs View File


+ 13
- 6
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -7,30 +7,37 @@ namespace Tensorflow
{ {
public static partial class c_api public static partial class c_api
{ {
/// <summary>
/// Get the OpList of all OpDefs defined in this address space.
/// </summary>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_GetAllOpList();

/// <summary> /// <summary>
/// For inputs that take a single tensor. /// For inputs that take a single tensor.
/// </summary> /// </summary>
/// <param name="desc"></param> /// <param name="desc"></param>
/// <param name="input"></param> /// <param name="input"></param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input);
public static unsafe extern void TF_AddInput(IntPtr desc, TF_Output input);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_FinishOperation(TF_OperationDescription desc, IntPtr status);
public static unsafe extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static unsafe extern TF_OperationDescription TF_NewOperation(IntPtr graph, string opType, string oper_name);
public static unsafe extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe int TF_OperationNumOutputs(IntPtr oper); public static extern unsafe int TF_OperationNumOutputs(IntPtr oper);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status);
public static extern unsafe void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, IntPtr value, IntPtr status);
public static extern unsafe void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value);
public static extern unsafe void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value);
} }
} }

+ 3
- 3
src/TensorFlowNET.Core/Operations/ops.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow


public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
{ {
var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name);
var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name);


// Add inputs // Add inputs
if(inputs != null) if(inputs != null)
@@ -45,12 +45,12 @@ namespace Tensorflow
var bytes = attr.Value.ToByteArray(); var bytes = attr.Value.ToByteArray();
var proto = Marshal.AllocHGlobal(bytes.Length); var proto = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, proto, bytes.Length); Marshal.Copy(bytes, 0, proto, bytes.Length);
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle);
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status);


if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message); if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message);
} }


var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);
var c_op = c_api.TF_FinishOperation(op_desc, status);


if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message); if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message);




+ 10
- 10
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow
_target = UTF8Encoding.UTF8.GetBytes(target); _target = UTF8Encoding.UTF8.GetBytes(target);
var opts = c_api.TF_NewSessionOptions(); var opts = c_api.TF_NewSessionOptions();
var status = new Status(); var status = new Status();
_session = c_api.TF_NewSession(_graph.Handle, opts, status.Handle);
_session = c_api.TF_NewSession(_graph, opts, status);


c_api.TF_DeleteSessionOptions(opts); c_api.TF_DeleteSessionOptions(opts);
} }
@@ -40,30 +40,30 @@ namespace Tensorflow
} }


public virtual object run(Tensor fetches, FeedDict feed_dict = null)
public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
{ {
var result = _run(fetches, feed_dict); var result = _run(fetches, feed_dict);


return result; return result;
} }


private unsafe object _run(Tensor fetches, FeedDict feed_dict = null)
private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
{ {
var feed_dict_tensor = new FeedDict();
var feed_dict_tensor = new Dictionary<Tensor, object>();


if (feed_dict != null) if (feed_dict != null)
{ {
NDArray np_val = null; NDArray np_val = null;
foreach (FeedValue feed in feed_dict)
foreach (var feed in feed_dict)
{ {
switch (feed.feed_val)
switch (feed.Value)
{ {
case float value: case float value:
np_val = np.asarray(value); np_val = np.asarray(value);
break; break;
} }


feed_dict_tensor[feed.feed] = np_val;
feed_dict_tensor[feed.Key] = np_val;
} }
} }


@@ -85,9 +85,9 @@ namespace Tensorflow
return fetch_handler.build_results(null, results); return fetch_handler.build_results(null, results);
} }


private object[] _do_run(List<Tensor> fetch_list, FeedDict feed_dict)
private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, object> feed_dict)
{ {
var feeds = feed_dict.items().Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray();
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray();
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();


return _call_tf_sessionrun(feeds, fetches); return _call_tf_sessionrun(feeds, fetches);
@@ -113,7 +113,7 @@ namespace Tensorflow
target_opers: new IntPtr[] { }, target_opers: new IntPtr[] { },
ntargets: 0, ntargets: 0,
run_metadata: IntPtr.Zero, run_metadata: IntPtr.Zero,
status: status.Handle);
status: status);


var result = output_values.Select(x => c_api.TF_TensorData(x)) var result = output_values.Select(x => c_api.TF_TensorData(x))
.Select(x => (object)*(float*)x) .Select(x => (object)*(float*)x)


+ 0
- 59
src/TensorFlowNET.Core/Sessions/FeedDict.cs View File

@@ -1,59 +0,0 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class FeedDict : IEnumerable
{
private Dictionary<Tensor, object> feed_dict;

public FeedDict()
{
feed_dict = new Dictionary<Tensor, object>();
}

public object this[Tensor feed]
{
get
{
return feed_dict[feed];
}

set
{
feed_dict[feed] = value;
}
}

public FeedDict Add(Tensor feed, object value)
{
feed_dict.Add(feed, value);
return this;
}

public IEnumerator GetEnumerator()
{
foreach (KeyValuePair<Tensor, object> feed in feed_dict)
{
yield return new FeedValue
{
feed = feed.Key,
feed_val = feed.Value
};
}
}

public Dictionary<Tensor, object> items()
{
return feed_dict;
}
}

public struct FeedValue
{
public Tensor feed { get; set; }
public object feed_val { get; set; }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow
private List<Tensor> _final_fetches = new List<Tensor>(); private List<Tensor> _final_fetches = new List<Tensor>();
private List<object> _targets = new List<object>(); private List<object> _targets = new List<object>();


public _FetchHandler(Graph graph, Tensor fetches, FeedDict feeds = null, object feed_handles = null)
public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, object> feeds = null, object feed_handles = null)
{ {
_fetch_mapper = new _FetchMapper().for_fetch(fetches); _fetch_mapper = new _FetchMapper().for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches()) foreach(var fetch in _fetch_mapper.unique_fetches())


+ 22
- 2
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -4,10 +4,13 @@ using System.Text;


namespace Tensorflow namespace Tensorflow
{ {
public class Status : IDisposable
/// <summary>
/// TF_Status holds error information. It either has an OK code, or
/// else an error code with an associated error message.
/// </summary>
public class Status
{ {
private readonly IntPtr _handle; private readonly IntPtr _handle;
public IntPtr Handle => _handle;


/// <summary> /// <summary>
/// Error message /// Error message
@@ -29,6 +32,23 @@ namespace Tensorflow
c_api.TF_SetStatus(_handle, code, msg); c_api.TF_SetStatus(_handle, code, msg);
} }


/// <summary>
/// Check status
/// Throw exception with error message if code != TF_OK
/// </summary>
public void Check()
{
if(Code != TF_Code.TF_OK)
{
throw new Exception(Message);
}
}

public static implicit operator IntPtr(Status status)
{
return status._handle;
}

public void Dispose() public void Dispose()
{ {
c_api.TF_DeleteStatus(_handle); c_api.TF_DeleteStatus(_handle);


+ 48
- 19
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -13,6 +13,8 @@ namespace Tensorflow
/// </summary> /// </summary>
public class Tensor public class Tensor
{ {
public IntPtr Handle { get; }

public Graph graph => op.graph; public Graph graph => op.graph;
public Operation op { get; } public Operation op { get; }


@@ -21,7 +23,6 @@ namespace Tensorflow
public int value_index { get; } public int value_index { get; }


public TF_DataType dtype { get; } public TF_DataType dtype { get; }
public IntPtr handle { get; }
public ulong bytesize { get; } public ulong bytesize { get; }
public ulong dataTypeSize { get;} public ulong dataTypeSize { get;}
public ulong size => bytesize / dataTypeSize; public ulong size => bytesize / dataTypeSize;
@@ -45,7 +46,7 @@ namespace Tensorflow


public Tensor(IntPtr handle) public Tensor(IntPtr handle)
{ {
this.handle = handle;
Handle = handle;
dtype = c_api.TF_TensorType(handle); dtype = c_api.TF_TensorType(handle);
rank = c_api.TF_NumDims(handle); rank = c_api.TF_NumDims(handle);
bytesize = c_api.TF_TensorByteSize(handle); bytesize = c_api.TF_TensorByteSize(handle);
@@ -59,33 +60,52 @@ namespace Tensorflow


public Tensor(NDArray nd) public Tensor(NDArray nd)
{ {
var data = Marshal.AllocHGlobal(sizeof(float) * nd.size);
Marshal.Copy(nd.Data<float>(), 0, data, nd.size);
var dataType = ToTFDataType(nd.dtype);
Handle = Allocate(nd);
dtype = c_api.TF_TensorType(Handle);
rank = c_api.TF_NumDims(Handle);
bytesize = c_api.TF_TensorByteSize(Handle);
buffer = c_api.TF_TensorData(Handle);
dataTypeSize = c_api.TF_DataTypeSize(dtype);

shape = new long[rank];
for (int i = 0; i < rank; i++)
shape[i] = c_api.TF_Dim(Handle, i);
}

private IntPtr Allocate(NDArray nd)
{
var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size);

switch (nd.dtype.Name)
{
case "Int32":
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size);
break;
case "Single":
Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size);
break;
case "Double":
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
break;
default:
throw new NotImplementedException("Marshal.Copy failed.");
}


var handle = c_api.TF_NewTensor(dataType,
var dataType = ToTFDataType(nd.dtype);
var tfHandle = c_api.TF_NewTensor(dataType,
nd.shape.Select(x => (long)x).ToArray(), // shape nd.shape.Select(x => (long)x).ToArray(), // shape
nd.ndim, nd.ndim,
data,
(UIntPtr)(nd.size * sizeof(float)),
dotHandle,
(UIntPtr)(nd.size * nd.dtypesize),
(IntPtr values, IntPtr len, ref bool closure) => (IntPtr values, IntPtr len, ref bool closure) =>
{ {
// Free the original buffer and set flag // Free the original buffer and set flag
Marshal.FreeHGlobal(data);
Marshal.FreeHGlobal(dotHandle);
closure = true; closure = true;
}, },
ref deallocator_called); ref deallocator_called);


this.handle = handle;
dtype = c_api.TF_TensorType(handle);
rank = c_api.TF_NumDims(handle);
bytesize = c_api.TF_TensorByteSize(handle);
buffer = c_api.TF_TensorData(handle);
dataTypeSize = c_api.TF_DataTypeSize(dtype);

shape = new long[rank];
for (int i = 0; i < rank; i++)
shape[i] = c_api.TF_Dim(handle, i);
return tfHandle;
} }


public Tensor(Operation op, int value_index, TF_DataType dtype) public Tensor(Operation op, int value_index, TF_DataType dtype)
@@ -129,11 +149,20 @@ namespace Tensorflow
{ {
switch (type.Name) switch (type.Name)
{ {
case "Int32":
return TF_DataType.TF_INT32;
case "Single": case "Single":
return TF_DataType.TF_FLOAT; return TF_DataType.TF_FLOAT;
case "Double":
return TF_DataType.TF_DOUBLE;
} }


return TF_DataType.DtInvalid; return TF_DataType.DtInvalid;
} }

public static implicit operator IntPtr(Tensor tensor)
{
return tensor.Handle;
}
} }
} }

+ 11
- 1
src/TensorFlowNET.Core/c_api.cs View File

@@ -10,12 +10,22 @@ namespace Tensorflow
/// ///
/// The API leans towards simplicity and uniformity instead of convenience /// The API leans towards simplicity and uniformity instead of convenience
/// since most usage will be by language specific wrappers. /// since most usage will be by language specific wrappers.
///
/// The params type mapping between .net and c_api
/// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op)
/// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph)
/// struct => struct (TF_Output output) => (TF_Output output)
/// const char* => string
/// int32_t => int
/// int64_t* => long[]
/// size_t* => unlong[]
/// void* => IntPtr
/// </summary> /// </summary>
public static partial class c_api public static partial class c_api
{ {
public const string TensorFlowLibName = "tensorflow"; public const string TensorFlowLibName = "tensorflow";


public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocatorData);
public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocator);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_Version(); public static unsafe extern IntPtr TF_Version();


+ 12
- 3
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -3,12 +3,21 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow; using Tensorflow;
using Buffer = Tensorflow.Buffer;


namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
{ {
[TestClass] [TestClass]
public class OperationsTest public class OperationsTest
{ {
[TestMethod]
public void GetAllOpList()
{
var handle = c_api.TF_GetAllOpList();
var buffer = new Buffer(handle);
Assert.IsTrue(buffer.Length == buffer.Data.Length);
}

[TestMethod] [TestMethod]
public void addInPlaceholder() public void addInPlaceholder()
{ {
@@ -18,9 +27,9 @@ namespace TensorFlowNET.UnitTest


using(var sess = tf.Session()) using(var sess = tf.Session())
{ {
var feed_dict = new FeedDict()
.Add(a, 3.0f)
.Add(b, 2.0f);
var feed_dict = new Dictionary<Tensor, object>();
feed_dict.Add(a, 3.0f);
feed_dict.Add(b, 2.0f);


var o = sess.run(c, feed_dict); var o = sess.run(c, feed_dict);
} }


+ 65
- 1
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest
public class TensorTest public class TensorTest
{ {
[TestMethod] [TestMethod]
public unsafe void NewTensor()
public void NewTensor()
{ {
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);


@@ -27,5 +27,69 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float)); Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float));
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array));
} }

/// <summary>
/// Port from tensorflow\c\c_api_test.cc
/// </summary>
[TestMethod]
public void SetShape()
{
var s = new Status();
var graph = tf.get_default_graph();

var desc = c_api.TF_NewOperation(graph, "Placeholder", "");
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_FLOAT);
//if (!dims.empty())
{
//TF_SetAttrShape(desc, "shape", dims.data(), dims.size());
}
var op = c_api.TF_FinishOperation(desc, s);

Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.IsNotNull(op);

// Fetch the shape, it should be completely unknown.
var feed_out_0 = new TF_Output { oper = op, index = 0 };
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);

Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.AreEqual(-1, num_dims);

// Set the shape to be unknown, expect no change.
c_api.TF_GraphSetTensorShape(graph, feed_out_0, new int[0], -1, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
Assert.AreEqual(-1, num_dims);

// Set the shape to be 2 x Unknown
var dims = new int[] { 2, -1 };
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
Assert.AreEqual(2, num_dims);

// Get the dimension vector appropriately.
var returned_dims = new int[dims.Length];
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

// Set to a new valid shape: [2, 3]
dims[1] = 3;
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
//Assert.IsTrue(s.Code == TF_Code.TF_OK);

// Fetch and see that the new value is returned.
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
//Assert.IsTrue(s.Code == TF_Code.TF_OK);
//Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

// Test for a scalar.
var three = c_test_util.ScalarConst(3, graph, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
var three_out_0 = new TF_Output { oper = three.Handle };
num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s);
Assert.AreEqual(0, num_dims);
}
} }
} }

+ 37
- 0
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -0,0 +1,37 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.UnitTest
{
public static class c_test_util
{
public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref IntPtr op)
{
var desc = c_api.TF_NewOperation(graph, "Const", name);
c_api.TF_SetAttrTensor(desc, "value", t.Handle, s);
s.Check();
c_api.TF_SetAttrType(desc, "dtype", t.dtype);
op = c_api.TF_FinishOperation(desc, s);
s.Check();
if(op == null)
{
throw new Exception("c_api.TF_FinishOperation failed.");
}
}

public static Operation Const(Tensor t, Graph graph, Status s, string name)
{
IntPtr op = IntPtr.Zero;
ConstHelper(t, graph, s, name, ref op);
return new Operation(op);
}

public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const")
{
return Const(new Tensor(v), graph, s, name);
}
}
}

Loading…
Cancel
Save