diff --git a/src/TensorFlowNET.Core/APIs/tf.constant.cs b/src/TensorFlowNET.Core/APIs/tf.constant.cs index 3b44c13d..b43d611d 100644 --- a/src/TensorFlowNET.Core/APIs/tf.constant.cs +++ b/src/TensorFlowNET.Core/APIs/tf.constant.cs @@ -9,12 +9,12 @@ namespace Tensorflow { public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) { - //constant_op.Create(nd, name, verify_shape); - var graph = tf.get_default_graph(); + var t = constant_op.Create(nd, name, verify_shape); + /*var graph = tf.get_default_graph(); var tensor = new Tensor(nd); - var op = graph.NewOperation("Const", name, tensor); + var op = graph.NewOperation("Const", name, tensor);*/ - return null; + return t; } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index 8cd202e4..ca726d5b 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -8,22 +8,22 @@ namespace Tensorflow { public partial class Graph { - public Operation NewOperation(string opType, string opName, Tensor tensor) + public OpDef GetOpDef(string type) { - var desc = c_api.TF_NewOperation(_handle, opType, opName); - - if (tensor.dtype == TF_DataType.TF_STRING) - { - var value = "Hello World!"; - var bytes = Encoding.UTF8.GetBytes(value); - var buf = Marshal.AllocHGlobal(bytes.Length + 1); - Marshal.Copy(bytes, 0, buf, bytes.Length); - c_api.TF_SetAttrString(desc, "value", buf, (uint)value.Length); - } - else + using (var buffer = new Buffer()) + using (var status = new Status()) { - c_api.TF_SetAttrTensor(desc, "value", tensor, Status); + c_api.TF_GraphGetOpDef(_handle, type, buffer, status); + return OpDef.Parser.ParseFrom(buffer.Data); } + } + + public OperationDescription NewOperation(string opType, string opName) + { + OperationDescription desc = c_api.TF_NewOperation(_handle, opType, opName); + return desc; + + /*c_api.TF_SetAttrTensor(desc, "value", tensor, Status); Status.Check(); @@ -32,7 +32,7 @@ namespace Tensorflow var op = c_api.TF_FinishOperation(desc, Status); Status.Check(); - return op; + return op;*/ } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 8f50f37b..9cae49fc 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -106,9 +106,15 @@ namespace Tensorflow original_op: null, op_def: op_def); + _create_op_helper(op, true); return op; } + private void _create_op_helper(Operation op, bool compute_device = true) + { + + } + public void _add_op(Operation op) { _nodes_by_id[op._id] = op; diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index d8a77dd1..0d3d8a93 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -109,12 +109,17 @@ namespace Tensorflow Graph = g; _id_value = Graph._next_id(); + if(op_def == null) + op_def = g.GetOpDef(node_def.Op); _handle = ops._create_c_op(g, node_def, inputs); - + _outputs = new Tensor[NumOutputs]; + output_types = new TF_DataType[NumOutputs]; + for (int i = 0; i < NumOutputs; i++) { + output_types[i] = OutputType(i); _outputs[i] = new Tensor(this, i, output_types[i]); } diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 262fe531..7bbd3088 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -197,7 +197,7 @@ namespace Tensorflow public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status); [DllImport(TensorFlowLibName)] - public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); + public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, uint proto_len, IntPtr status); /// /// Set `num_dims` to -1 to represent "unknown rank". diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/Operations/ops.cs index 85ef04ea..3344d5af 100644 --- a/src/TensorFlowNET.Core/Operations/ops.cs +++ b/src/TensorFlowNET.Core/Operations/ops.cs @@ -25,17 +25,17 @@ namespace Tensorflow public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs) { - var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name); + var op_desc = graph.NewOperation(node_def.Op, node_def.Name); // Add inputs if(inputs != null && inputs.Count > 0) { - foreach (var op_input in inputs) + /*foreach (var op_input in inputs) { c_api.TF_AddInput(op_desc, op_input._as_tf_output()); - } + }*/ - //c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); + c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); } var status = new Status(); @@ -48,9 +48,10 @@ namespace Tensorflow var bytes = attr.Value.ToByteArray(); var proto = Marshal.AllocHGlobal(bytes.Length); Marshal.Copy(bytes, 0, proto, bytes.Length); - c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status); + + c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status); - if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message); + status.Check(true); } var c_op = c_api.TF_FinishOperation(op_desc, status); @@ -60,6 +61,11 @@ namespace Tensorflow return c_op; } + public static OpDef _get_op_def(Graph graph, string type) + { + return graph.GetOpDef(type); + } + public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary attrs = null) { var node_def = new node_def_pb2.NodeDef(); diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index d9739f7b..85343ba8 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -115,12 +115,27 @@ namespace Tensorflow run_metadata: IntPtr.Zero, status: status); - var result = output_values.Select(x => c_api.TF_TensorData(x)) - .Select(x => (object)*(float*)x) - .ToArray(); + object[] result = new object[fetch_list.Length]; - var op = new Operation(fetch_list[0].oper); - //var metadata = c_api.TF_OperationGetAttrMetadata(fetch_list[0].oper, "dtype", status); + for (int i = 0; i < fetch_list.Length; i++) + { + var tensor = new Tensor(output_values[i]); + + switch (tensor.dtype) + { + case TF_DataType.TF_STRING: + // wired, don't know why we have to start from offset 9. + var bytes = tensor.Data(); + result[i] = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9); + break; + case TF_DataType.TF_FLOAT: + result[i] = *(float*)c_api.TF_TensorData(output_values[i]); + break; + default: + throw new NotImplementedException("can't get output"); + break; + } + } return result; } diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index a4648307..c93304c7 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -8,7 +8,7 @@ namespace Tensorflow /// TF_Status holds error information. It either has an OK code, or /// else an error code with an associated error message. /// - public class Status + public class Status : IDisposable { private readonly IntPtr _handle; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 3fa6ca5d..a6765a8f 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -146,7 +146,6 @@ namespace Tensorflow { var data = new byte[bytesize]; Marshal.Copy(buffer, data, 0, (int)bytesize); - return data; } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 3aa643d0..5298f78c 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -22,17 +22,21 @@ namespace Tensorflow public static Tensor Create(NDArray nd, string name = "Const", bool verify_shape = false) { Graph g = ops.get_default_graph(); - var tensor_value = new AttrValue(); var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape); - tensor_value.Tensor = tensor_pb; + var tensor_value = new AttrValue + { + Type = tensor_pb.Dtype, + Tensor = tensor_pb + }; + var dtype_value = new AttrValue { Type = tensor_value.Tensor.Dtype, }; var attrs = new Dictionary(); - attrs["dtype"] = dtype_value; attrs["value"] = tensor_value; + attrs["dtype"] = dtype_value; var op = g.create_op("Const", null, diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 052adee6..d1fe6b97 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -2,6 +2,7 @@ using NumSharp.Core.Interfaces; using System; using System.Collections.Generic; +using System.Linq; using System.Text; using tensor_pb2 = Tensorflow; @@ -32,7 +33,7 @@ namespace Tensorflow tensor_proto.DoubleVal.AddRange(nd.Data()); break; case "String": - tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFrom(nd.Data()[0], Encoding.UTF8)); + tensor_proto.StringVal.AddRange(nd.Data().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); break; default: throw new Exception("Not Implemented");