diff --git a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs
index 9dc36fa8..4e99300d 100644
--- a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs
+++ b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs
@@ -18,7 +18,7 @@ namespace Tensorflow
/// TF_Status*
///
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_OperationGetAttrMetadata(IntPtr oper, string attr_name, IntPtr status);
+ public static extern TF_AttrMetadata TF_OperationGetAttrMetadata(IntPtr oper, string attr_name, IntPtr status);
///
/// Fills in `value` with the value of the attribute `attr_name`. `value` must
@@ -71,7 +71,7 @@ namespace Tensorflow
/// const void*
/// size_t
[DllImport(TensorFlowLibName)]
- public static extern void TF_SetAttrString(IntPtr desc, string attr_name, IntPtr value, uint length);
+ public static extern void TF_SetAttrString(IntPtr desc, string attr_name, string value, uint length);
///
///
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index 5d1f4f83..a4b23294 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -145,6 +145,11 @@ namespace Tensorflow
return ret;
}
+ public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
+ {
+ return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
+ }
+
public NodeDef GetNodeDef()
{
using (var s = new Status())
diff --git a/src/TensorFlowNET.Core/Operations/OperationDescription.cs b/src/TensorFlowNET.Core/Operations/OperationDescription.cs
index 38190c17..2e9016ff 100644
--- a/src/TensorFlowNET.Core/Operations/OperationDescription.cs
+++ b/src/TensorFlowNET.Core/Operations/OperationDescription.cs
@@ -8,6 +8,11 @@ namespace Tensorflow
{
private IntPtr _handle;
+ public OperationDescription(Graph graph, string opType, string opName)
+ {
+ _handle = c_api.TF_NewOperation(graph, opType, opName);
+ }
+
public OperationDescription(IntPtr handle)
{
_handle = handle;
@@ -18,6 +23,16 @@ namespace Tensorflow
c_api.TF_AddInputList(_handle, inputs, inputs.Length);
}
+ public void SetAttrType(string attr_name, TF_DataType value)
+ {
+ c_api.TF_SetAttrType(_handle, attr_name, value);
+ }
+
+ public void SetAttrShape(string attr_name, long[] dims)
+ {
+ c_api.TF_SetAttrShape(_handle, attr_name, dims, dims.Length);
+ }
+
public Operation FinishOperation(Status status)
{
return c_api.TF_FinishOperation(_handle, status);
diff --git a/src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs b/src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs
index 41ae6a3b..dd73da93 100644
--- a/src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs
+++ b/src/TensorFlowNET.Core/Operations/TF_AttrMetadata.cs
@@ -6,7 +6,7 @@ namespace Tensorflow
{
public struct TF_AttrMetadata
{
- public char is_list;
+ public byte is_list;
public long list_size;
public TF_AttrType type;
public long total_size;
diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
index cb3ae75e..ac293dae 100644
--- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs
+++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
@@ -46,6 +46,16 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status);
+ ///
+ /// Operation will only be added to *graph when TF_FinishOperation() is
+ /// called (assuming TF_FinishOperation() does not return an error).
+ /// *graph must not be deleted until after TF_FinishOperation() is
+ /// called.
+ ///
+ /// TF_Graph*
+ /// const char*
+ /// const char*
+ /// TF_OperationDescription*
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name);
diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs
index b4535572..ae303fe9 100644
--- a/test/TensorFlowNET.Examples/LinearRegression.cs
+++ b/test/TensorFlowNET.Examples/LinearRegression.cs
@@ -42,6 +42,13 @@ namespace TensorFlowNET.Examples
// Mean squared error
var sub = pred - Y;
var pow = tf.pow(sub, 2);
+
+
+
+
+
+
+
var reduce = tf.reduce_sum(pow);
var cost = reduce / (2d * n_samples);
diff --git a/test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs b/test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs
index 2a9db62d..b29e4393 100644
--- a/test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs
+++ b/test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs
@@ -65,12 +65,11 @@ namespace TensorFlowNET.UnitTest
public void String()
{
var desc = init("string");
- var handle = Marshal.StringToHGlobalAnsi("bunny");
- c_api.TF_SetAttrString(desc, "v", handle, 5);
+ c_api.TF_SetAttrString(desc, "v", "bunny", 5);
- //var oper = c_api.TF_FinishOperation(desc, s_);
- //ASSERT_EQ(TF_Code.TF_OK, s_.Code);
- //EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5);
+ var oper = c_api.TF_FinishOperation(desc, s_);
+ ASSERT_EQ(TF_Code.TF_OK, s_.Code);
+ EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5);
//var value = new char[5];
//c_api.TF_OperationGetAttrString(oper, "v", value, 5, s_);
@@ -78,6 +77,17 @@ namespace TensorFlowNET.UnitTest
//EXPECT_EQ("bunny", value, 5));
}
+ [TestMethod]
+ public void GetAttributesTest()
+ {
+ var desc = graph_.NewOperation("Placeholder", "node");
+ desc.SetAttrType("dtype", TF_DataType.TF_FLOAT);
+ long[] ref_shape = new long[3] { 1, 2, 3 };
+ desc.SetAttrShape("shape", ref_shape);
+ var oper = desc.FinishOperation(s_);
+ var metadata = oper.GetAttributeMetadata("shape", s_);
+ }
+
public void Dispose()
{
graph_.Dispose();