diff --git a/src/TensorFlowNET.Core/APIs/tf.constant.cs b/src/TensorFlowNET.Core/APIs/tf.constant.cs
index 3b44c13d..b43d611d 100644
--- a/src/TensorFlowNET.Core/APIs/tf.constant.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.constant.cs
@@ -9,12 +9,12 @@ namespace Tensorflow
{
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false)
{
- //constant_op.Create(nd, name, verify_shape);
- var graph = tf.get_default_graph();
+ var t = constant_op.Create(nd, name, verify_shape);
+ /*var graph = tf.get_default_graph();
var tensor = new Tensor(nd);
- var op = graph.NewOperation("Const", name, tensor);
+ var op = graph.NewOperation("Const", name, tensor);*/
- return null;
+ return t;
}
}
}
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
index 8cd202e4..ca726d5b 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
@@ -8,22 +8,22 @@ namespace Tensorflow
{
public partial class Graph
{
- public Operation NewOperation(string opType, string opName, Tensor tensor)
+ public OpDef GetOpDef(string type)
{
- var desc = c_api.TF_NewOperation(_handle, opType, opName);
-
- if (tensor.dtype == TF_DataType.TF_STRING)
- {
- var value = "Hello World!";
- var bytes = Encoding.UTF8.GetBytes(value);
- var buf = Marshal.AllocHGlobal(bytes.Length + 1);
- Marshal.Copy(bytes, 0, buf, bytes.Length);
- c_api.TF_SetAttrString(desc, "value", buf, (uint)value.Length);
- }
- else
+ using (var buffer = new Buffer())
+ using (var status = new Status())
{
- c_api.TF_SetAttrTensor(desc, "value", tensor, Status);
+ c_api.TF_GraphGetOpDef(_handle, type, buffer, status);
+ return OpDef.Parser.ParseFrom(buffer.Data);
}
+ }
+
+ public OperationDescription NewOperation(string opType, string opName)
+ {
+ OperationDescription desc = c_api.TF_NewOperation(_handle, opType, opName);
+ return desc;
+
+ /*c_api.TF_SetAttrTensor(desc, "value", tensor, Status);
Status.Check();
@@ -32,7 +32,7 @@ namespace Tensorflow
var op = c_api.TF_FinishOperation(desc, Status);
Status.Check();
- return op;
+ return op;*/
}
}
}
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index 8f50f37b..9cae49fc 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -106,9 +106,15 @@ namespace Tensorflow
original_op: null,
op_def: op_def);
+ _create_op_helper(op, true);
return op;
}
+ private void _create_op_helper(Operation op, bool compute_device = true)
+ {
+
+ }
+
public void _add_op(Operation op)
{
_nodes_by_id[op._id] = op;
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index d8a77dd1..0d3d8a93 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -109,12 +109,17 @@ namespace Tensorflow
Graph = g;
_id_value = Graph._next_id();
+ if(op_def == null)
+ op_def = g.GetOpDef(node_def.Op);
_handle = ops._create_c_op(g, node_def, inputs);
-
+
_outputs = new Tensor[NumOutputs];
+ output_types = new TF_DataType[NumOutputs];
+
for (int i = 0; i < NumOutputs; i++)
{
+ output_types[i] = OutputType(i);
_outputs[i] = new Tensor(this, i, output_types[i]);
}
diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
index 262fe531..7bbd3088 100644
--- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs
+++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
@@ -197,7 +197,7 @@ namespace Tensorflow
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status);
[DllImport(TensorFlowLibName)]
- public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status);
+ public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, uint proto_len, IntPtr status);
///
/// Set `num_dims` to -1 to represent "unknown rank".
diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/Operations/ops.cs
index 85ef04ea..3344d5af 100644
--- a/src/TensorFlowNET.Core/Operations/ops.cs
+++ b/src/TensorFlowNET.Core/Operations/ops.cs
@@ -25,17 +25,17 @@ namespace Tensorflow
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs)
{
- var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name);
+ var op_desc = graph.NewOperation(node_def.Op, node_def.Name);
// Add inputs
if(inputs != null && inputs.Count > 0)
{
- foreach (var op_input in inputs)
+ /*foreach (var op_input in inputs)
{
c_api.TF_AddInput(op_desc, op_input._as_tf_output());
- }
+ }*/
- //c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count);
+ c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count);
}
var status = new Status();
@@ -48,9 +48,10 @@ namespace Tensorflow
var bytes = attr.Value.ToByteArray();
var proto = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, proto, bytes.Length);
- c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status);
+
+ c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status);
- if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message);
+ status.Check(true);
}
var c_op = c_api.TF_FinishOperation(op_desc, status);
@@ -60,6 +61,11 @@ namespace Tensorflow
return c_op;
}
+ public static OpDef _get_op_def(Graph graph, string type)
+ {
+ return graph.GetOpDef(type);
+ }
+
public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary attrs = null)
{
var node_def = new node_def_pb2.NodeDef();
diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
index d9739f7b..85343ba8 100644
--- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs
+++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
@@ -115,12 +115,27 @@ namespace Tensorflow
run_metadata: IntPtr.Zero,
status: status);
- var result = output_values.Select(x => c_api.TF_TensorData(x))
- .Select(x => (object)*(float*)x)
- .ToArray();
+ object[] result = new object[fetch_list.Length];
- var op = new Operation(fetch_list[0].oper);
- //var metadata = c_api.TF_OperationGetAttrMetadata(fetch_list[0].oper, "dtype", status);
+ for (int i = 0; i < fetch_list.Length; i++)
+ {
+ var tensor = new Tensor(output_values[i]);
+
+ switch (tensor.dtype)
+ {
+ case TF_DataType.TF_STRING:
+ // wired, don't know why we have to start from offset 9.
+ var bytes = tensor.Data();
+ result[i] = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9);
+ break;
+ case TF_DataType.TF_FLOAT:
+ result[i] = *(float*)c_api.TF_TensorData(output_values[i]);
+ break;
+ default:
+ throw new NotImplementedException("can't get output");
+ break;
+ }
+ }
return result;
}
diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs
index a4648307..c93304c7 100644
--- a/src/TensorFlowNET.Core/Status/Status.cs
+++ b/src/TensorFlowNET.Core/Status/Status.cs
@@ -8,7 +8,7 @@ namespace Tensorflow
/// TF_Status holds error information. It either has an OK code, or
/// else an error code with an associated error message.
///
- public class Status
+ public class Status : IDisposable
{
private readonly IntPtr _handle;
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 3fa6ca5d..a6765a8f 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -146,7 +146,6 @@ namespace Tensorflow
{
var data = new byte[bytesize];
Marshal.Copy(buffer, data, 0, (int)bytesize);
-
return data;
}
diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs
index 3aa643d0..5298f78c 100644
--- a/src/TensorFlowNET.Core/Tensors/constant_op.cs
+++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs
@@ -22,17 +22,21 @@ namespace Tensorflow
public static Tensor Create(NDArray nd, string name = "Const", bool verify_shape = false)
{
Graph g = ops.get_default_graph();
- var tensor_value = new AttrValue();
var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape);
- tensor_value.Tensor = tensor_pb;
+ var tensor_value = new AttrValue
+ {
+ Type = tensor_pb.Dtype,
+ Tensor = tensor_pb
+ };
+
var dtype_value = new AttrValue
{
Type = tensor_value.Tensor.Dtype,
};
var attrs = new Dictionary();
- attrs["dtype"] = dtype_value;
attrs["value"] = tensor_value;
+ attrs["dtype"] = dtype_value;
var op = g.create_op("Const",
null,
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index 052adee6..d1fe6b97 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -2,6 +2,7 @@
using NumSharp.Core.Interfaces;
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
using tensor_pb2 = Tensorflow;
@@ -32,7 +33,7 @@ namespace Tensorflow
tensor_proto.DoubleVal.AddRange(nd.Data());
break;
case "String":
- tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFrom(nd.Data()[0], Encoding.UTF8));
+ tensor_proto.StringVal.AddRange(nd.Data().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
break;
default:
throw new Exception("Not Implemented");