diff --git a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs index 9dc36fa8..4e99300d 100644 --- a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs @@ -18,7 +18,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_OperationGetAttrMetadata(IntPtr oper, string attr_name, IntPtr status); + public static extern TF_AttrMetadata TF_OperationGetAttrMetadata(IntPtr oper, string attr_name, IntPtr status); /// /// Fills in `value` with the value of the attribute `attr_name`. `value` must @@ -71,7 +71,7 @@ namespace Tensorflow /// const void* /// size_t [DllImport(TensorFlowLibName)] - public static extern void TF_SetAttrString(IntPtr desc, string attr_name, IntPtr value, uint length); + public static extern void TF_SetAttrString(IntPtr desc, string attr_name, string value, uint length); /// /// diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 5d1f4f83..a4b23294 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -145,6 +145,11 @@ namespace Tensorflow return ret; } + public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) + { + return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); + } + public NodeDef GetNodeDef() { using (var s = new Status()) diff --git a/src/TensorFlowNET.Core/Operations/OperationDescription.cs b/src/TensorFlowNET.Core/Operations/OperationDescription.cs index 38190c17..2e9016ff 100644 --- a/src/TensorFlowNET.Core/Operations/OperationDescription.cs +++ b/src/TensorFlowNET.Core/Operations/OperationDescription.cs @@ -8,6 +8,11 @@ namespace Tensorflow { private IntPtr _handle; + public OperationDescription(Graph graph, string opType, string opName) + { + _handle = c_api.TF_NewOperation(graph, opType, opName); + } + public OperationDescription(IntPtr handle) { _handle = handle; @@ -18,6 +23,16 @@ namespace Tensorflow c_api.TF_AddInputList(_handle, inputs, inputs.Length); } + public void SetAttrType(string attr_name, TF_DataType value) + { + c_api.TF_SetAttrType(_handle, attr_name, value); + } + + public void SetAttrShape(string attr_name, long[] dims) + { + c_api.TF_SetAttrShape(_handle, attr_name, dims, dims.Length); + } + public Operation FinishOperation(Status status) { return c_api.TF_FinishOperation(_handle, status); diff --git a/src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs b/src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs index 41ae6a3b..dd73da93 100644 --- a/src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs +++ b/src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs @@ -6,7 +6,7 @@ namespace Tensorflow { public struct TF_AttrMetadata { - public char is_list; + public byte is_list; public long list_size; public TF_AttrType type; public long total_size; diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index cb3ae75e..ac293dae 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -46,6 +46,16 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); + /// + /// Operation will only be added to *graph when TF_FinishOperation() is + /// called (assuming TF_FinishOperation() does not return an error). + /// *graph must not be deleted until after TF_FinishOperation() is + /// called. + /// + /// TF_Graph* + /// const char* + /// const char* + /// TF_OperationDescription* [DllImport(TensorFlowLibName)] public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index b4535572..ae303fe9 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -42,6 +42,13 @@ namespace TensorFlowNET.Examples // Mean squared error var sub = pred - Y; var pow = tf.pow(sub, 2); + + + + + + + var reduce = tf.reduce_sum(pow); var cost = reduce / (2d * n_samples); diff --git a/test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs b/test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs index 2a9db62d..b29e4393 100644 --- a/test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs +++ b/test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs @@ -65,12 +65,11 @@ namespace TensorFlowNET.UnitTest public void String() { var desc = init("string"); - var handle = Marshal.StringToHGlobalAnsi("bunny"); - c_api.TF_SetAttrString(desc, "v", handle, 5); + c_api.TF_SetAttrString(desc, "v", "bunny", 5); - //var oper = c_api.TF_FinishOperation(desc, s_); - //ASSERT_EQ(TF_Code.TF_OK, s_.Code); - //EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5); + var oper = c_api.TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_Code.TF_OK, s_.Code); + EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5); //var value = new char[5]; //c_api.TF_OperationGetAttrString(oper, "v", value, 5, s_); @@ -78,6 +77,17 @@ namespace TensorFlowNET.UnitTest //EXPECT_EQ("bunny", value, 5)); } + [TestMethod] + public void GetAttributesTest() + { + var desc = graph_.NewOperation("Placeholder", "node"); + desc.SetAttrType("dtype", TF_DataType.TF_FLOAT); + long[] ref_shape = new long[3] { 1, 2, 3 }; + desc.SetAttrShape("shape", ref_shape); + var oper = desc.FinishOperation(s_); + var metadata = oper.GetAttributeMetadata("shape", s_); + } + public void Dispose() { graph_.Dispose();