| @@ -9,6 +9,7 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| protected TF_Code TF_OK = TF_Code.TF_OK; | |||
| protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | |||
| protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL; | |||
| protected void EXPECT_TRUE(bool expected, string msg = "") | |||
| => Assert.IsTrue(expected, msg); | |||
| @@ -73,6 +74,9 @@ namespace TensorFlowNET.UnitTest | |||
| protected void TF_DeleteStatus(IntPtr s) | |||
| => c_api.TF_DeleteStatus(s); | |||
| protected void TF_DeleteTensor(IntPtr t) | |||
| => c_api.TF_DeleteTensor(t); | |||
| protected IntPtr TF_TensorData(IntPtr t) | |||
| => c_api.TF_TensorData(t); | |||
| @@ -94,6 +98,9 @@ namespace TensorFlowNET.UnitTest | |||
| protected IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status) | |||
| => c_api.TFE_NewOp(ctx, op_or_function_name, status); | |||
| protected IntPtr TFE_NewTensorHandle(IntPtr t, IntPtr status) | |||
| => c_api.TFE_NewTensorHandle(t, status); | |||
| protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status) | |||
| => c_api.TFE_Execute(op, retvals, ref num_retvals, status); | |||
| @@ -0,0 +1,57 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using Tensorflow; | |||
| using Tensorflow.Eager; | |||
| using Buffer = System.Buffer; | |||
| using System.Linq; | |||
| namespace TensorFlowNET.UnitTest.Eager | |||
| { | |||
| public partial class CApiEagerTest | |||
| { | |||
| /// <summary> | |||
| /// TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) | |||
| /// </summary> | |||
| [TestMethod] | |||
| public unsafe void OpInferMixedTypeInputListAttrs() | |||
| { | |||
| var status = TF_NewStatus(); | |||
| var opts = TFE_NewContextOptions(); | |||
| var ctx = TFE_NewContext(opts, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TFE_DeleteContextOptions(opts); | |||
| var condition = TestScalarTensorHandle(true); | |||
| var t1 = TestMatrixTensorHandle(); | |||
| var t2 = TestAxisTensorHandle(); | |||
| var assertOp = TFE_NewOp(ctx, "Assert", status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TFE_OpAddInput(assertOp, condition, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var data = new[] { condition, t1, t2 }; | |||
| TFE_OpAddInputList(assertOp, data, 3, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var attr_values = Graph.TFE_GetOpDef("Assert").Attr; | |||
| var attr_found = attr_values.First(x => x.Name == "T"); | |||
| EXPECT_NE(attr_found, attr_values.Last()); | |||
| // EXPECT_EQ(attr_found.Type[0], "DT_BOOL"); | |||
| //EXPECT_EQ(attr_found->second.list().type(1), tensorflow::DataType::DT_FLOAT); | |||
| //EXPECT_EQ(attr_found->second.list().type(2), tensorflow::DataType::DT_INT32); | |||
| var retvals = new IntPtr[1]; | |||
| int num_retvals = 1; | |||
| TFE_Execute(assertOp, retvals, ref num_retvals, status); | |||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TF_DeleteStatus(status); | |||
| TFE_DeleteOp(assertOp); | |||
| TFE_DeleteTensorHandle(condition); | |||
| TFE_DeleteTensorHandle(t1); | |||
| TFE_DeleteTensorHandle(t2); | |||
| TFE_DeleteTensorHandle(retvals[0]); | |||
| TFE_DeleteContext(ctx); | |||
| } | |||
| } | |||
| } | |||
| @@ -120,5 +120,45 @@ namespace TensorFlowNET.UnitTest.Eager | |||
| return var_handle[0]; | |||
| } | |||
| IntPtr TestAxisTensorHandle() | |||
| { | |||
| var dims = new long[] { 1 }; | |||
| var data = new int[] { 1 }; | |||
| var t = c_api.TF_AllocateTensor(TF_DataType.TF_INT32, dims, 1, sizeof(int)); | |||
| memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); | |||
| var status = TF_NewStatus(); | |||
| var th = c_api.TFE_NewTensorHandle(t, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TF_DeleteTensor(t); | |||
| TF_DeleteStatus(status); | |||
| return th; | |||
| } | |||
| IntPtr TestScalarTensorHandle(bool value) | |||
| { | |||
| var data = new[] { value }; | |||
| var t = c_api.TF_AllocateTensor(TF_BOOL, null, 0, sizeof(bool)); | |||
| memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); | |||
| var status = TF_NewStatus(); | |||
| var th = TFE_NewTensorHandle(t, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TF_DeleteTensor(t); | |||
| TF_DeleteStatus(status); | |||
| return th; | |||
| } | |||
| IntPtr TestScalarTensorHandle(float value) | |||
| { | |||
| var data = new [] { value }; | |||
| var t = c_api.TF_AllocateTensor(TF_FLOAT, null, 0, sizeof(float)); | |||
| memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); | |||
| var status = TF_NewStatus(); | |||
| var th = TFE_NewTensorHandle(t, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TF_DeleteTensor(t); | |||
| TF_DeleteStatus(status); | |||
| return th; | |||
| } | |||
| } | |||
| } | |||