Browse Source

finished HelloWorld example.

tags/v0.1.0-Tensor
haiping008 6 years ago
parent
commit
ccce43861f
11 changed files with 73 additions and 37 deletions
  1. +4
    -4
      src/TensorFlowNET.Core/APIs/tf.constant.cs
  2. +14
    -14
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  3. +6
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  4. +6
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  6. +12
    -6
      src/TensorFlowNET.Core/Operations/ops.cs
  7. +20
    -5
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Status/Status.cs
  9. +0
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  10. +7
    -3
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  11. +2
    -1
      src/TensorFlowNET.Core/Tensors/tensor_util.cs

+ 4
- 4
src/TensorFlowNET.Core/APIs/tf.constant.cs View File

@@ -9,12 +9,12 @@ namespace Tensorflow
{ {
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false)
{ {
//constant_op.Create(nd, name, verify_shape);
var graph = tf.get_default_graph();
var t = constant_op.Create(nd, name, verify_shape);
/*var graph = tf.get_default_graph();
var tensor = new Tensor(nd); var tensor = new Tensor(nd);
var op = graph.NewOperation("Const", name, tensor);
var op = graph.NewOperation("Const", name, tensor);*/
return null;
return t;
} }
} }
} }

+ 14
- 14
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -8,22 +8,22 @@ namespace Tensorflow
{ {
public partial class Graph public partial class Graph
{ {
public Operation NewOperation(string opType, string opName, Tensor tensor)
public OpDef GetOpDef(string type)
{ {
var desc = c_api.TF_NewOperation(_handle, opType, opName);

if (tensor.dtype == TF_DataType.TF_STRING)
{
var value = "Hello World!";
var bytes = Encoding.UTF8.GetBytes(value);
var buf = Marshal.AllocHGlobal(bytes.Length + 1);
Marshal.Copy(bytes, 0, buf, bytes.Length);
c_api.TF_SetAttrString(desc, "value", buf, (uint)value.Length);
}
else
using (var buffer = new Buffer())
using (var status = new Status())
{ {
c_api.TF_SetAttrTensor(desc, "value", tensor, Status);
c_api.TF_GraphGetOpDef(_handle, type, buffer, status);
return OpDef.Parser.ParseFrom(buffer.Data);
} }
}

public OperationDescription NewOperation(string opType, string opName)
{
OperationDescription desc = c_api.TF_NewOperation(_handle, opType, opName);
return desc;

/*c_api.TF_SetAttrTensor(desc, "value", tensor, Status);
Status.Check(); Status.Check();


@@ -32,7 +32,7 @@ namespace Tensorflow
var op = c_api.TF_FinishOperation(desc, Status); var op = c_api.TF_FinishOperation(desc, Status);
Status.Check(); Status.Check();


return op;
return op;*/
} }
} }
} }

+ 6
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -106,9 +106,15 @@ namespace Tensorflow
original_op: null, original_op: null,
op_def: op_def); op_def: op_def);


_create_op_helper(op, true);
return op; return op;
} }


private void _create_op_helper(Operation op, bool compute_device = true)
{

}

public void _add_op(Operation op) public void _add_op(Operation op)
{ {
_nodes_by_id[op._id] = op; _nodes_by_id[op._id] = op;


+ 6
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -109,12 +109,17 @@ namespace Tensorflow
Graph = g; Graph = g;


_id_value = Graph._next_id(); _id_value = Graph._next_id();
if(op_def == null)
op_def = g.GetOpDef(node_def.Op);


_handle = ops._create_c_op(g, node_def, inputs); _handle = ops._create_c_op(g, node_def, inputs);
_outputs = new Tensor[NumOutputs]; _outputs = new Tensor[NumOutputs];
output_types = new TF_DataType[NumOutputs];

for (int i = 0; i < NumOutputs; i++) for (int i = 0; i < NumOutputs; i++)
{ {
output_types[i] = OutputType(i);
_outputs[i] = new Tensor(this, i, output_types[i]); _outputs[i] = new Tensor(this, i, output_types[i]);
} }




+ 1
- 1
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -197,7 +197,7 @@ namespace Tensorflow
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status); public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status);
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, uint proto_len, IntPtr status);


/// <summary> /// <summary>
/// Set `num_dims` to -1 to represent "unknown rank". /// Set `num_dims` to -1 to represent "unknown rank".


+ 12
- 6
src/TensorFlowNET.Core/Operations/ops.cs View File

@@ -25,17 +25,17 @@ namespace Tensorflow


public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
{ {
var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name);
var op_desc = graph.NewOperation(node_def.Op, node_def.Name);


// Add inputs // Add inputs
if(inputs != null && inputs.Count > 0) if(inputs != null && inputs.Count > 0)
{ {
foreach (var op_input in inputs)
/*foreach (var op_input in inputs)
{ {
c_api.TF_AddInput(op_desc, op_input._as_tf_output()); c_api.TF_AddInput(op_desc, op_input._as_tf_output());
}
}*/


//c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count);
c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count);
} }


var status = new Status(); var status = new Status();
@@ -48,9 +48,10 @@ namespace Tensorflow
var bytes = attr.Value.ToByteArray(); var bytes = attr.Value.ToByteArray();
var proto = Marshal.AllocHGlobal(bytes.Length); var proto = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, proto, bytes.Length); Marshal.Copy(bytes, 0, proto, bytes.Length);
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status);
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status);


if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message);
status.Check(true);
} }


var c_op = c_api.TF_FinishOperation(op_desc, status); var c_op = c_api.TF_FinishOperation(op_desc, status);
@@ -60,6 +61,11 @@ namespace Tensorflow
return c_op; return c_op;
} }


public static OpDef _get_op_def(Graph graph, string type)
{
return graph.GetOpDef(type);
}

public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null) public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
{ {
var node_def = new node_def_pb2.NodeDef(); var node_def = new node_def_pb2.NodeDef();


+ 20
- 5
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -115,12 +115,27 @@ namespace Tensorflow
run_metadata: IntPtr.Zero, run_metadata: IntPtr.Zero,
status: status); status: status);


var result = output_values.Select(x => c_api.TF_TensorData(x))
.Select(x => (object)*(float*)x)
.ToArray();
object[] result = new object[fetch_list.Length];


var op = new Operation(fetch_list[0].oper);
//var metadata = c_api.TF_OperationGetAttrMetadata(fetch_list[0].oper, "dtype", status);
for (int i = 0; i < fetch_list.Length; i++)
{
var tensor = new Tensor(output_values[i]);
switch (tensor.dtype)
{
case TF_DataType.TF_STRING:
// wired, don't know why we have to start from offset 9.
var bytes = tensor.Data();
result[i] = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9);
break;
case TF_DataType.TF_FLOAT:
result[i] = *(float*)c_api.TF_TensorData(output_values[i]);
break;
default:
throw new NotImplementedException("can't get output");
break;
}
}


return result; return result;
} }


+ 1
- 1
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow
/// TF_Status holds error information. It either has an OK code, or /// TF_Status holds error information. It either has an OK code, or
/// else an error code with an associated error message. /// else an error code with an associated error message.
/// </summary> /// </summary>
public class Status
public class Status : IDisposable
{ {
private readonly IntPtr _handle; private readonly IntPtr _handle;




+ 0
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -146,7 +146,6 @@ namespace Tensorflow
{ {
var data = new byte[bytesize]; var data = new byte[bytesize];
Marshal.Copy(buffer, data, 0, (int)bytesize); Marshal.Copy(buffer, data, 0, (int)bytesize);

return data; return data;
} }




+ 7
- 3
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -22,17 +22,21 @@ namespace Tensorflow
public static Tensor Create(NDArray nd, string name = "Const", bool verify_shape = false) public static Tensor Create(NDArray nd, string name = "Const", bool verify_shape = false)
{ {
Graph g = ops.get_default_graph(); Graph g = ops.get_default_graph();
var tensor_value = new AttrValue();
var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape); var tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape);
tensor_value.Tensor = tensor_pb;
var tensor_value = new AttrValue
{
Type = tensor_pb.Dtype,
Tensor = tensor_pb
};
var dtype_value = new AttrValue var dtype_value = new AttrValue
{ {
Type = tensor_value.Tensor.Dtype, Type = tensor_value.Tensor.Dtype,
}; };


var attrs = new Dictionary<string, AttrValue>(); var attrs = new Dictionary<string, AttrValue>();
attrs["dtype"] = dtype_value;
attrs["value"] = tensor_value; attrs["value"] = tensor_value;
attrs["dtype"] = dtype_value;


var op = g.create_op("Const", var op = g.create_op("Const",
null, null,


+ 2
- 1
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -2,6 +2,7 @@
using NumSharp.Core.Interfaces; using NumSharp.Core.Interfaces;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;
using tensor_pb2 = Tensorflow; using tensor_pb2 = Tensorflow;


@@ -32,7 +33,7 @@ namespace Tensorflow
tensor_proto.DoubleVal.AddRange(nd.Data<double>()); tensor_proto.DoubleVal.AddRange(nd.Data<double>());
break; break;
case "String": case "String":
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFrom(nd.Data<string>()[0], Encoding.UTF8));
tensor_proto.StringVal.AddRange(nd.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
break; break;
default: default:
throw new Exception("Not Implemented"); throw new Exception("Not Implemented");


Loading…
Cancel
Save