| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public struct TF_AttrMetadata | |||||
| { | |||||
| public char is_list; | |||||
| public long list_size; | |||||
| public TF_AttrType type; | |||||
| public long total_size; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,19 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public enum TF_AttrType | |||||
| { | |||||
| TF_ATTR_STRING = 0, | |||||
| TF_ATTR_INT = 1, | |||||
| TF_ATTR_FLOAT = 2, | |||||
| TF_ATTR_BOOL = 3, | |||||
| TF_ATTR_TYPE = 4, | |||||
| TF_ATTR_SHAPE = 5, | |||||
| TF_ATTR_TENSOR = 6, | |||||
| TF_ATTR_PLACEHOLDER = 7, | |||||
| TF_ATTR_FUNC = 8 | |||||
| } | |||||
| } | |||||
| @@ -40,6 +40,33 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_OperationDevice(IntPtr oper); | public static extern IntPtr TF_OperationDevice(IntPtr oper); | ||||
| /// <summary> | |||||
| /// Fills in `value` with the value of the attribute `attr_name`. `value` must | |||||
| /// point to an array of length at least `max_length` (ideally set to | |||||
| /// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, | |||||
| /// attr_name)). | |||||
| /// </summary> | |||||
| /// <param name="oper">TF_Operation*</param> | |||||
| /// <param name="attr_name">const char*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern TF_AttrMetadata TF_OperationGetAttrMetadata(IntPtr oper, string attr_name, IntPtr status); | |||||
| /// <summary> | |||||
| /// Fills in `value` with the value of the attribute `attr_name`. `value` must | |||||
| /// point to an array of length at least `max_length` (ideally set to | |||||
| /// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, | |||||
| /// attr_name)). | |||||
| /// </summary> | |||||
| /// <param name="oper">TF_Operation*</param> | |||||
| /// <param name="attr_name">const char*</param> | |||||
| /// <param name="value">void* </param> | |||||
| /// <param name="max_length">size_t</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_OperationGetAttrString(IntPtr oper, string attr_name, IntPtr value, uint max_length, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Sets `output_attr_value` to the binary-serialized AttrValue proto | /// Sets `output_attr_value` to the binary-serialized AttrValue proto | ||||
| /// representation of the value of the `attr_name` attr of `oper`. | /// representation of the value of the `attr_name` attr of `oper`. | ||||
| @@ -170,6 +197,20 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_SetAttrShape(IntPtr desc, string attr_name, long[] dims, int num_dims); | public static extern void TF_SetAttrShape(IntPtr desc, string attr_name, long[] dims, int num_dims); | ||||
| /// <summary> | |||||
| /// Call some TF_SetAttr*() function for every attr that is not | |||||
| /// inferred from an input and doesn't have a default value you wish to | |||||
| /// keep. | |||||
| /// | |||||
| /// `value` must point to a string of length `length` bytes. | |||||
| /// </summary> | |||||
| /// <param name="desc">TF_OperationDescription*</param> | |||||
| /// <param name="attr_name">const char*</param> | |||||
| /// <param name="value">const void*</param> | |||||
| /// <param name="length">size_t</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_SetAttrString(IntPtr desc, string attr_name, string value, uint length); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); | public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); | ||||
| @@ -0,0 +1,85 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | |||||
| /// <summary> | |||||
| /// tensorflow\c\c_api_test.cc | |||||
| /// `class CApiAttributesTest` | |||||
| /// </summary> | |||||
| [TestClass] | |||||
| public class CApiAttributesTestcs : CApiTest, IDisposable | |||||
| { | |||||
| private Graph graph_; | |||||
| private int counter_; | |||||
| private Status s_; | |||||
| public CApiAttributesTestcs() | |||||
| { | |||||
| s_ = new Status(); | |||||
| graph_ = new Graph(); | |||||
| } | |||||
| private OperationDescription init(string type) | |||||
| { | |||||
| // Construct op_name to match the name used by REGISTER_OP in the | |||||
| // ATTR_TEST_REGISTER calls above. | |||||
| string op_name = "CApiAttributesTestOp"; | |||||
| if (type.Contains("list(")) | |||||
| { | |||||
| op_name += "List"; | |||||
| type = type.Substring(5, type.Length - 6); | |||||
| } | |||||
| op_name += type; | |||||
| return c_api.TF_NewOperation(graph_, op_name, $"name{counter_++}"); | |||||
| } | |||||
| /// <summary> | |||||
| /// REGISTER_OP for CApiAttributesTest test cases. | |||||
| /// Registers two ops, each with a single attribute called 'v'. | |||||
| /// The attribute in one op will have a type 'type', the other | |||||
| /// will have list(type). | |||||
| /// </summary> | |||||
| /// <param name="type"></param> | |||||
| private void ATTR_TEST_REGISTER_OP(string type) | |||||
| { | |||||
| } | |||||
| private void EXPECT_TF_META(Operation oper, string attr_name, int expected_list_size, TF_AttrType expected_type, uint expected_total_size) | |||||
| { | |||||
| var m = c_api.TF_OperationGetAttrMetadata(oper, attr_name, s_); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s_.Code); | |||||
| char e = expected_list_size >= 0 ? (char)1 : (char)0; | |||||
| EXPECT_EQ(e, m.is_list); | |||||
| EXPECT_EQ(expected_list_size, m.list_size); | |||||
| EXPECT_EQ(expected_type, m.type); | |||||
| EXPECT_EQ(expected_total_size, m.total_size); | |||||
| } | |||||
| [TestMethod] | |||||
| public void String() | |||||
| { | |||||
| //var desc = init("string"); | |||||
| //c_api.TF_SetAttrString(desc, "v", "bunny", 5); | |||||
| //var oper = c_api.TF_FinishOperation(desc, s_); | |||||
| //ASSERT_EQ(TF_Code.TF_OK, s_.Code); | |||||
| //EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5); | |||||
| //var value = new char[5]; | |||||
| //c_api.TF_OperationGetAttrString(oper, "v", value, 5, s_); | |||||
| //EXPECT_EQ(TF_Code.TF_OK, s_.Code); | |||||
| //EXPECT_EQ("bunny", value, 5)); | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| graph_.Dispose(); | |||||
| s_.Dispose(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -176,11 +176,6 @@ namespace TensorFlowNET.UnitTest | |||||
| } | } | ||||
| public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") | public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") | ||||
| { | |||||
| return NegHelper(n, graph, s, name); | |||||
| } | |||||
| public static Operation NegHelper(Operation n, Graph graph, Status s, string name) | |||||
| { | { | ||||
| OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | ||||
| var neg_input = new TF_Output(n, 0); | var neg_input = new TF_Output(n, 0); | ||||
| @@ -221,19 +216,5 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| return Const(new Tensor(v), graph, s, name); | return Const(new Tensor(v), graph, s, name); | ||||
| } | } | ||||
| public static unsafe IntPtr Int32Tensor(int v) | |||||
| { | |||||
| bool deallocator_called = false; | |||||
| const int num_bytes = sizeof(int); | |||||
| var dotHandle = Marshal.AllocHGlobal(num_bytes * 1); | |||||
| *(int*)dotHandle = v; | |||||
| return c_api.TF_NewTensor(TF_DataType.TF_INT32, new long[0], 0, dotHandle, num_bytes, | |||||
| (IntPtr values, IntPtr len, ref bool closure) => | |||||
| { | |||||
| // Free the original buffer and set flag | |||||
| // Marshal.FreeHGlobal(dotHandle); | |||||
| }, ref deallocator_called); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||