| @@ -9,12 +9,12 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) | public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) | ||||
| { | { | ||||
| //constant_op.Create(nd, name, verify_shape); | |||||
| var graph = tf.get_default_graph(); | |||||
| var t = constant_op.Create(nd, name, verify_shape); | |||||
| /*var graph = tf.get_default_graph(); | |||||
| var tensor = new Tensor(nd); | var tensor = new Tensor(nd); | ||||
| var op = graph.NewOperation("Const", name, tensor); | |||||
| var op = graph.NewOperation("Const", name, tensor);*/ | |||||
| return null; | |||||
| return t; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -8,22 +8,22 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class Graph | public partial class Graph | ||||
| { | { | ||||
| public Operation NewOperation(string opType, string opName, Tensor tensor) | |||||
| public OpDef GetOpDef(string type) | |||||
| { | { | ||||
| var desc = c_api.TF_NewOperation(_handle, opType, opName); | |||||
| if (tensor.dtype == TF_DataType.TF_STRING) | |||||
| { | |||||
| var value = "Hello World!"; | |||||
| var bytes = Encoding.UTF8.GetBytes(value); | |||||
| var buf = Marshal.AllocHGlobal(bytes.Length + 1); | |||||
| Marshal.Copy(bytes, 0, buf, bytes.Length); | |||||
| c_api.TF_SetAttrString(desc, "value", buf, (uint)value.Length); | |||||
| } | |||||
| else | |||||
| using (var buffer = new Buffer()) | |||||
| using (var status = new Status()) | |||||
| { | { | ||||
| c_api.TF_SetAttrTensor(desc, "value", tensor, Status); | |||||
| c_api.TF_GraphGetOpDef(_handle, type, buffer, status); | |||||
| return OpDef.Parser.ParseFrom(buffer.Data); | |||||
| } | } | ||||
| } | |||||
| public OperationDescription NewOperation(string opType, string opName) | |||||
| { | |||||
| OperationDescription desc = c_api.TF_NewOperation(_handle, opType, opName); | |||||
| return desc; | |||||
| /*c_api.TF_SetAttrTensor(desc, "value", tensor, Status); | |||||
| Status.Check(); | Status.Check(); | ||||
| @@ -32,7 +32,7 @@ namespace Tensorflow | |||||
| var op = c_api.TF_FinishOperation(desc, Status); | var op = c_api.TF_FinishOperation(desc, Status); | ||||
| Status.Check(); | Status.Check(); | ||||
| return op; | |||||
| return op;*/ | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -106,9 +106,15 @@ namespace Tensorflow | |||||
| original_op: null, | original_op: null, | ||||
| op_def: op_def); | op_def: op_def); | ||||
| _create_op_helper(op, true); | |||||
| return op; | return op; | ||||
| } | } | ||||
| private void _create_op_helper(Operation op, bool compute_device = true) | |||||
| { | |||||
| } | |||||
| public void _add_op(Operation op) | public void _add_op(Operation op) | ||||
| { | { | ||||
| _nodes_by_id[op._id] = op; | _nodes_by_id[op._id] = op; | ||||
| @@ -109,12 +109,17 @@ namespace Tensorflow | |||||
| Graph = g; | Graph = g; | ||||
| _id_value = Graph._next_id(); | _id_value = Graph._next_id(); | ||||
| if(op_def == null) | |||||
| op_def = g.GetOpDef(node_def.Op); | |||||
| _handle = ops._create_c_op(g, node_def, inputs); | _handle = ops._create_c_op(g, node_def, inputs); | ||||
| _outputs = new Tensor[NumOutputs]; | _outputs = new Tensor[NumOutputs]; | ||||
| output_types = new TF_DataType[NumOutputs]; | |||||
| for (int i = 0; i < NumOutputs; i++) | for (int i = 0; i < NumOutputs; i++) | ||||
| { | { | ||||
| output_types[i] = OutputType(i); | |||||
| _outputs[i] = new Tensor(this, i, output_types[i]); | _outputs[i] = new Tensor(this, i, output_types[i]); | ||||
| } | } | ||||
| @@ -197,7 +197,7 @@ namespace Tensorflow | |||||
| public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status); | public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); | |||||
| public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, uint proto_len, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Set `num_dims` to -1 to represent "unknown rank". | /// Set `num_dims` to -1 to represent "unknown rank". | ||||
| @@ -25,17 +25,17 @@ namespace Tensorflow | |||||
| public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) | public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) | ||||
| { | { | ||||
| var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name); | |||||
| var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||||
| // Add inputs | // Add inputs | ||||
| if(inputs != null && inputs.Count > 0) | if(inputs != null && inputs.Count > 0) | ||||
| { | { | ||||
| foreach (var op_input in inputs) | |||||
| /*foreach (var op_input in inputs) | |||||
| { | { | ||||
| c_api.TF_AddInput(op_desc, op_input._as_tf_output()); | c_api.TF_AddInput(op_desc, op_input._as_tf_output()); | ||||
| } | |||||
| }*/ | |||||
| //c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); | |||||
| c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); | |||||
| } | } | ||||
| var status = new Status(); | var status = new Status(); | ||||
| @@ -48,9 +48,10 @@ namespace Tensorflow | |||||
| var bytes = attr.Value.ToByteArray(); | var bytes = attr.Value.ToByteArray(); | ||||
| var proto = Marshal.AllocHGlobal(bytes.Length); | var proto = Marshal.AllocHGlobal(bytes.Length); | ||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | Marshal.Copy(bytes, 0, proto, bytes.Length); | ||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status); | |||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status); | |||||
| if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message); | |||||
| status.Check(true); | |||||
| } | } | ||||
| var c_op = c_api.TF_FinishOperation(op_desc, status); | var c_op = c_api.TF_FinishOperation(op_desc, status); | ||||
| @@ -60,6 +61,11 @@ namespace Tensorflow | |||||
| return c_op; | return c_op; | ||||
| } | } | ||||
| public static OpDef _get_op_def(Graph graph, string type) | |||||
| { | |||||
| return graph.GetOpDef(type); | |||||
| } | |||||
| public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null) | public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null) | ||||
| { | { | ||||
| var node_def = new node_def_pb2.NodeDef(); | var node_def = new node_def_pb2.NodeDef(); | ||||
| @@ -115,12 +115,27 @@ namespace Tensorflow | |||||
| run_metadata: IntPtr.Zero, | run_metadata: IntPtr.Zero, | ||||
| status: status); | status: status); | ||||
| var result = output_values.Select(x => c_api.TF_TensorData(x)) | |||||
| .Select(x => (object)*(float*)x) | |||||
| .ToArray(); | |||||
| object[] result = new object[fetch_list.Length]; | |||||
| var op = new Operation(fetch_list[0].oper); | |||||
| //var metadata = c_api.TF_OperationGetAttrMetadata(fetch_list[0].oper, "dtype", status); | |||||
| for (int i = 0; i < fetch_list.Length; i++) | |||||
| { | |||||
| var tensor = new Tensor(output_values[i]); | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_STRING: | |||||
| // wired, don't know why we have to start from offset 9. | |||||
| var bytes = tensor.Data(); | |||||
| result[i] = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9); | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| result[i] = *(float*)c_api.TF_TensorData(output_values[i]); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("can't get output"); | |||||
| break; | |||||
| } | |||||
| } | |||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -8,7 +8,7 @@ namespace Tensorflow | |||||
| /// TF_Status holds error information. It either has an OK code, or | /// TF_Status holds error information. It either has an OK code, or | ||||
| /// else an error code with an associated error message. | /// else an error code with an associated error message. | ||||
| /// </summary> | /// </summary> | ||||
| public class Status | |||||
| public class Status : IDisposable | |||||
| { | { | ||||
| private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
| @@ -146,7 +146,6 @@ namespace Tensorflow | |||||
| { | { | ||||
| var data = new byte[bytesize]; | var data = new byte[bytesize]; | ||||
| Marshal.Copy(buffer, data, 0, (int)bytesize); | Marshal.Copy(buffer, data, 0, (int)bytesize); | ||||
| return data; | return data; | ||||
| } | } | ||||
| @@ -22,17 +22,21 @@ namespace Tensorflow | |||||
| public static Tensor Create(NDArray nd, string name = "Const", bool verify_shape = false) | public static Tensor Create(NDArray nd, string name = "Const", bool verify_shape = false) | ||||
| { | { | ||||
| Graph g = ops.get_default_graph(); | Graph g = ops.get_default_graph(); | ||||
| var tensor_value = new AttrValue(); | |||||
| var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape); | var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape); | ||||
| tensor_value.Tensor = tensor_pb; | |||||
| var tensor_value = new AttrValue | |||||
| { | |||||
| Type = tensor_pb.Dtype, | |||||
| Tensor = tensor_pb | |||||
| }; | |||||
| var dtype_value = new AttrValue | var dtype_value = new AttrValue | ||||
| { | { | ||||
| Type = tensor_value.Tensor.Dtype, | Type = tensor_value.Tensor.Dtype, | ||||
| }; | }; | ||||
| var attrs = new Dictionary<string, AttrValue>(); | var attrs = new Dictionary<string, AttrValue>(); | ||||
| attrs["dtype"] = dtype_value; | |||||
| attrs["value"] = tensor_value; | attrs["value"] = tensor_value; | ||||
| attrs["dtype"] = dtype_value; | |||||
| var op = g.create_op("Const", | var op = g.create_op("Const", | ||||
| null, | null, | ||||
| @@ -2,6 +2,7 @@ | |||||
| using NumSharp.Core.Interfaces; | using NumSharp.Core.Interfaces; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using tensor_pb2 = Tensorflow; | using tensor_pb2 = Tensorflow; | ||||
| @@ -32,7 +33,7 @@ namespace Tensorflow | |||||
| tensor_proto.DoubleVal.AddRange(nd.Data<double>()); | tensor_proto.DoubleVal.AddRange(nd.Data<double>()); | ||||
| break; | break; | ||||
| case "String": | case "String": | ||||
| tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFrom(nd.Data<string>()[0], Encoding.UTF8)); | |||||
| tensor_proto.StringVal.AddRange(nd.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x))); | |||||
| break; | break; | ||||
| default: | default: | ||||
| throw new Exception("Not Implemented"); | throw new Exception("Not Implemented"); | ||||