diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs similarity index 100% rename from src/TensorFlowNET.Core/c_api.cs rename to src/TensorFlowNET.Core/APIs/c_api.cs diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/APIs/tf.constant.cs similarity index 53% rename from src/TensorFlowNET.Core/Tensors/tf.constant.cs rename to src/TensorFlowNET.Core/APIs/tf.constant.cs index d60fb50e..3b44c13d 100644 --- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs +++ b/src/TensorFlowNET.Core/APIs/tf.constant.cs @@ -9,7 +9,12 @@ namespace Tensorflow { public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false) { - return constant_op.Create(nd, name, verify_shape); + //constant_op.Create(nd, name, verify_shape); + var graph = tf.get_default_graph(); + var tensor = new Tensor(nd); + var op = graph.NewOperation("Const", name, tensor); + + return null; } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs new file mode 100644 index 00000000..8cd202e4 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -0,0 +1,38 @@ +using NumSharp.Core; +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + public partial class Graph + { + public Operation NewOperation(string opType, string opName, Tensor tensor) + { + var desc = c_api.TF_NewOperation(_handle, opType, opName); + + if (tensor.dtype == TF_DataType.TF_STRING) + { + var value = "Hello World!"; + var bytes = Encoding.UTF8.GetBytes(value); + var buf = Marshal.AllocHGlobal(bytes.Length + 1); + Marshal.Copy(bytes, 0, buf, bytes.Length); + c_api.TF_SetAttrString(desc, "value", buf, (uint)value.Length); + } + else + { + c_api.TF_SetAttrTensor(desc, "value", tensor, Status); + } + + Status.Check(); + + c_api.TF_SetAttrType(desc, "dtype", tensor.dtype); + + var op = c_api.TF_FinishOperation(desc, Status); + Status.Check(); + + return op; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index a20fe727..8f50f37b 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -41,21 +41,6 @@ namespace Tensorflow _names_in_use = new Dictionary(); } - public Operation NewOperation(string opType, string opName, Tensor t) - { - var desc = c_api.TF_NewOperation(_handle, opType, opName); - - c_api.TF_SetAttrTensor(desc, "value", t, Status); - Status.Check(); - - c_api.TF_SetAttrType(desc, "dtype", t.dtype); - - var op = c_api.TF_FinishOperation(desc, Status); - Status.Check(); - - return op; - } - public T as_graph_element(T obj, bool allow_tensor = true, bool allow_operation = true) { return _as_graph_element_locked(obj, allow_tensor, allow_operation); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index c321af8a..3fa6ca5d 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -83,8 +83,14 @@ namespace Tensorflow Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); break; case "String": - dotHandle = Marshal.StringToHGlobalAuto(nd.Data()[0]); - size = (ulong)nd.Data()[0].Length; + var value = nd.Data()[0]; + var bytes = Encoding.UTF8.GetBytes(value); + var buf = Marshal.AllocHGlobal(bytes.Length + 1); + Marshal.Copy(bytes, 0, buf, bytes.Length); + + //c_api.TF_SetAttrString(op, "value", buf, (uint)bytes.Length); + + size = (ulong)bytes.Length; break; default: throw new NotImplementedException("Marshal.Copy failed.");