| @@ -24,23 +24,36 @@ namespace Tensorflow | |||||
| private List<String> _unfetchable_ops = new List<string>(); | private List<String> _unfetchable_ops = new List<string>(); | ||||
| private string _name_stack; | private string _name_stack; | ||||
| public Status Status { get; } | |||||
| public Graph() | public Graph() | ||||
| { | { | ||||
| _handle = c_api.TF_NewGraph(); | _handle = c_api.TF_NewGraph(); | ||||
| Status = new Status(); | |||||
| } | } | ||||
| public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
| { | { | ||||
| _handle = graph; | _handle = graph; | ||||
| Status = new Status(); | |||||
| _nodes_by_id = new Dictionary<int, Operation>(); | _nodes_by_id = new Dictionary<int, Operation>(); | ||||
| _nodes_by_name = new Dictionary<string, Operation>(); | _nodes_by_name = new Dictionary<string, Operation>(); | ||||
| _names_in_use = new Dictionary<string, int>(); | _names_in_use = new Dictionary<string, int>(); | ||||
| } | } | ||||
| public OperationDescription NewOperation(string opType, string opName) | |||||
| public Operation NewOperation(string opType, string opName, Tensor t) | |||||
| { | { | ||||
| return c_api.TF_NewOperation(_handle, opType, opName); | |||||
| var desc = c_api.TF_NewOperation(_handle, opType, opName); | |||||
| c_api.TF_SetAttrTensor(desc, "value", t, Status); | |||||
| Status.Check(); | |||||
| c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||||
| var op = c_api.TF_FinishOperation(desc, Status); | |||||
| Status.Check(); | |||||
| return op; | |||||
| } | } | ||||
| public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true) | public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true) | ||||
| @@ -7,6 +7,18 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static partial class c_api | public static partial class c_api | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Request that `desc` be co-located on the device where `op` | |||||
| /// is placed. | |||||
| /// | |||||
| /// Use of this is discouraged since the implementation of device placement is | |||||
| /// subject to change. Primarily intended for internal libraries | |||||
| /// </summary> | |||||
| /// <param name="desc"></param> | |||||
| /// <param name="op"></param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_ColocateWith(IntPtr desc, IntPtr op); | |||||
| /// <summary> | /// <summary> | ||||
| /// Get the OpList of all OpDefs defined in this address space. | /// Get the OpList of all OpDefs defined in this address space. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -209,7 +221,7 @@ namespace Tensorflow | |||||
| /// <param name="value">const void*</param> | /// <param name="value">const void*</param> | ||||
| /// <param name="length">size_t</param> | /// <param name="length">size_t</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_SetAttrString(IntPtr desc, string attr_name, string value, uint length); | |||||
| public static extern void TF_SetAttrString(IntPtr desc, string attr_name, IntPtr value, uint length); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -119,6 +119,9 @@ namespace Tensorflow | |||||
| .Select(x => (object)*(float*)x) | .Select(x => (object)*(float*)x) | ||||
| .ToArray(); | .ToArray(); | ||||
| var op = new Operation(fetch_list[0].oper); | |||||
| //var metadata = c_api.TF_OperationGetAttrMetadata(fetch_list[0].oper, "dtype", status); | |||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -7,9 +7,19 @@ namespace Tensorflow | |||||
| public class Session : BaseSession | public class Session : BaseSession | ||||
| { | { | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| public Status Status { get; } | |||||
| public SessionOptions Options { get; } | |||||
| public Session(string target = "", Graph graph = null) | public Session(string target = "", Graph graph = null) | ||||
| { | { | ||||
| Status = new Status(); | |||||
| if(graph == null) | |||||
| { | |||||
| graph = tf.get_default_graph(); | |||||
| } | |||||
| Options = new SessionOptions(); | |||||
| _handle = c_api.TF_NewSession(graph, Options, Status); | |||||
| Status.Check(); | |||||
| } | } | ||||
| public Session(IntPtr handle) | public Session(IntPtr handle) | ||||
| @@ -36,12 +36,15 @@ namespace Tensorflow | |||||
| /// Check status | /// Check status | ||||
| /// Throw exception with error message if code != TF_OK | /// Throw exception with error message if code != TF_OK | ||||
| /// </summary> | /// </summary> | ||||
| public void Check() | |||||
| public void Check(bool throwException = false) | |||||
| { | { | ||||
| if(Code != TF_Code.TF_OK) | if(Code != TF_Code.TF_OK) | ||||
| { | { | ||||
| Console.WriteLine(Message); | Console.WriteLine(Message); | ||||
| // throw new Exception(Message); | |||||
| if (throwException) | |||||
| { | |||||
| throw new Exception(Message); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -69,6 +69,7 @@ namespace Tensorflow | |||||
| private IntPtr Allocate(NDArray nd) | private IntPtr Allocate(NDArray nd) | ||||
| { | { | ||||
| var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); | var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); | ||||
| ulong size = (ulong)(nd.size * nd.dtypesize); | |||||
| switch (nd.dtype.Name) | switch (nd.dtype.Name) | ||||
| { | { | ||||
| @@ -81,16 +82,21 @@ namespace Tensorflow | |||||
| case "Double": | case "Double": | ||||
| Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); | Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); | ||||
| break; | break; | ||||
| case "String": | |||||
| dotHandle = Marshal.StringToHGlobalAuto(nd.Data<string>()[0]); | |||||
| size = (ulong)nd.Data<string>()[0].Length; | |||||
| break; | |||||
| default: | default: | ||||
| throw new NotImplementedException("Marshal.Copy failed."); | throw new NotImplementedException("Marshal.Copy failed."); | ||||
| } | } | ||||
| var dataType = ToTFDataType(nd.dtype); | var dataType = ToTFDataType(nd.dtype); | ||||
| var tfHandle = c_api.TF_NewTensor(dataType, | var tfHandle = c_api.TF_NewTensor(dataType, | ||||
| nd.shape.Select(x => (long)x).ToArray(), // shape | nd.shape.Select(x => (long)x).ToArray(), // shape | ||||
| nd.ndim, | nd.ndim, | ||||
| dotHandle, | dotHandle, | ||||
| (ulong)(nd.size * nd.dtypesize), | |||||
| size, | |||||
| (IntPtr values, IntPtr len, ref bool closure) => | (IntPtr values, IntPtr len, ref bool closure) => | ||||
| { | { | ||||
| // Free the original buffer and set flag | // Free the original buffer and set flag | ||||
| @@ -154,6 +160,8 @@ namespace Tensorflow | |||||
| return TF_DataType.TF_FLOAT; | return TF_DataType.TF_FLOAT; | ||||
| case "Double": | case "Double": | ||||
| return TF_DataType.TF_DOUBLE; | return TF_DataType.TF_DOUBLE; | ||||
| case "String": | |||||
| return TF_DataType.TF_STRING; | |||||
| } | } | ||||
| return TF_DataType.DtInvalid; | return TF_DataType.DtInvalid; | ||||
| @@ -34,12 +34,13 @@ namespace Tensorflow | |||||
| attrs["dtype"] = dtype_value; | attrs["dtype"] = dtype_value; | ||||
| attrs["value"] = tensor_value; | attrs["value"] = tensor_value; | ||||
| var const_tensor = g.create_op("Const", | |||||
| null, | |||||
| new TF_DataType[] { (TF_DataType)dtype_value.Type }, | |||||
| var op = g.create_op("Const", | |||||
| null, | |||||
| new TF_DataType[] { (TF_DataType)dtype_value.Type }, | |||||
| attrs: attrs, | attrs: attrs, | ||||
| name: name).outputs[0]; | |||||
| name: name); | |||||
| var const_tensor = op.outputs[0]; | |||||
| const_tensor.value = nd.Data(); | const_tensor.value = nd.Data(); | ||||
| return const_tensor; | return const_tensor; | ||||
| @@ -7,9 +7,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static partial class tf | public static partial class tf | ||||
| { | { | ||||
| public static Tensor constant(NDArray value, string name = "Const", bool verify_shape = false) | |||||
| public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) | |||||
| { | { | ||||
| return constant_op.Create(value, name, verify_shape); | |||||
| return constant_op.Create(nd, name, verify_shape); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -24,7 +24,7 @@ namespace TensorFlowNET.Examples | |||||
| var sess = tf.Session(); | var sess = tf.Session(); | ||||
| // Run the op | // Run the op | ||||
| sess.run(hello); | |||||
| Console.WriteLine(sess.run(hello)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -23,6 +23,8 @@ namespace TensorFlowNET.Examples | |||||
| Console.ReadLine(); | Console.ReadLine(); | ||||
| } | } | ||||
| } | } | ||||
| Console.ReadLine(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -25,12 +25,17 @@ namespace TensorFlowNET.UnitTest | |||||
| public void SetUp() | public void SetUp() | ||||
| { | { | ||||
| feed1_ = c_test_util.Placeholder(graph_, s_, "feed1"); | feed1_ = c_test_util.Placeholder(graph_, s_, "feed1"); | ||||
| s_.Check(); | |||||
| feed2_ = c_test_util.Placeholder(graph_, s_, "feed2"); | feed2_ = c_test_util.Placeholder(graph_, s_, "feed2"); | ||||
| s_.Check(); | |||||
| constant_ = c_test_util.ScalarConst(10, graph_, s_); | constant_ = c_test_util.ScalarConst(10, graph_, s_); | ||||
| desc_ = graph_.NewOperation("AddN", "add"); | |||||
| s_.Check(); | |||||
| desc_ = c_api.TF_NewOperation(graph_, "AddN", "add"); | |||||
| s_.Check(); | |||||
| TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) }; | TF_Output[] inputs = { new TF_Output(feed1_, 0), new TF_Output(constant_, 0) }; | ||||
| desc_.AddInputList(inputs); | desc_.AddInputList(inputs); | ||||
| s_.Check(); | |||||
| } | } | ||||
| private void SetViaStringList(OperationDescription desc, string[] list) | private void SetViaStringList(OperationDescription desc, string[] list) | ||||
| @@ -85,7 +90,8 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestMethod] | [TestMethod] | ||||
| public void ColocateWith() | public void ColocateWith() | ||||
| { | { | ||||
| c_api.TF_ColocateWith(desc_, feed1_); | |||||
| FinishAndVerify(desc_, new string[] { "loc:@feed1" }); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||