From a132d4b8b953849b184af284c87cbbf10de2fa5d Mon Sep 17 00:00:00 2001 From: Esther2013 Date: Tue, 8 Jan 2019 09:22:57 -0600 Subject: [PATCH] CApiGradientsTest c_api.TF_SetAttrTensor exception threw. --- .../Operations/c_api.ops.cs | 3 + .../CApiGradientsTest.cs | 116 ++++++++++++++++++ test/TensorFlowNET.UnitTest/CApiTest.cs | 45 ++++++- 3 files changed, 161 insertions(+), 3 deletions(-) create mode 100644 test/TensorFlowNET.UnitTest/CApiGradientsTest.cs diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 96be50d2..462a4321 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -196,6 +196,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status); + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); + [DllImport(TensorFlowLibName)] public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, uint proto_len, IntPtr status); diff --git a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs new file mode 100644 index 00000000..5a7eeba3 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs @@ -0,0 +1,116 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp.Core; +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.UnitTest +{ + /// + /// tensorflow\c\c_api_test.cc + /// `class CApiGradientsTest` + /// + [TestClass] + public class CApiGradientsTest : CApiTest, IDisposable + { + private Graph graph_ = new Graph(); + private Graph expected_graph_ = new Graph(); + private Status s_ = new Status(); + + private void TestGradientsSuccess(bool grad_inputs_provided) + { + var inputs = new TF_Output[2]; + var outputs = new TF_Output[1]; + var grad_outputs = new TF_Output[2]; + var expected_grad_outputs = new TF_Output[2]; + + BuildSuccessGraph(inputs, outputs); + } + + private void BuildSuccessGraph(TF_Output[] inputs, TF_Output[] outputs) + { + // Construct the following graph: + // | + // z| + // | + // MatMul + // / \ + // ^ ^ + // | | + // x| y| + // | | + // | | + // Const_0 Const_1 + // + var const0_val = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + var const1_val = new float[] { 1.0f, 0.0f, 0.0f, 1.0f }; + var const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0"); + var const1 = FloatConst2x2(graph_, s_, const1_val, "Const_1"); + var matmul = MatMul(graph_, s_, const0, const1, "MatMul"); + inputs[0] = new TF_Output(const0, 0); + inputs[1] = new TF_Output(const1, 0); + outputs[0] = new TF_Output(matmul, 0); + EXPECT_EQ(TF_OK, TF_GetCode(s_)); + } + + private Operation FloatConst2x2(Graph graph, Status s, float[] values, string name) + { + var tensor = FloatTensor2x2(values); + var desc = TF_NewOperation(graph, "Const", name); + TF_SetAttrTensor(desc, "value", tensor, s); + if (TF_GetCode(s) != TF_OK) return IntPtr.Zero; + TF_SetAttrType(desc, "dtype", TF_FLOAT); + var op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + return op; + } + + private Tensor FloatTensor2x2(float[] values) + { + long[] dims = { 2, 2 }; + Tensor t = c_api.TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4); + Marshal.Copy(values, 0, t, 4); + return t; + } + + private Operation MatMul(Graph graph, Status s, Operation l, Operation r, string name, + bool transpose_a = false, bool transpose_b = false) + { + var desc = TF_NewOperation(graph, "MatMul", name); + if (transpose_a) + { + TF_SetAttrBool(desc, "transpose_a", true); + } + if (transpose_b) + { + TF_SetAttrBool(desc, "transpose_b", true); + } + TF_AddInput(desc, new TF_Output(l, 0)); + TF_AddInput(desc, new TF_Output(r, 0)); + var op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + return op; + } + + [TestMethod] + public void Gradients_GradInputs() + { + TestGradientsSuccess(true); + } + + [TestMethod] + public void Gradients_NoGradInputs() + { + TestGradientsSuccess(false); + } + + public void Dispose() + { + graph_.Dispose(); + expected_graph_.Dispose(); + s_.Dispose(); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/CApiTest.cs b/test/TensorFlowNET.UnitTest/CApiTest.cs index fea2c6e4..bfecd5c1 100644 --- a/test/TensorFlowNET.UnitTest/CApiTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiTest.cs @@ -2,24 +2,63 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow; namespace TensorFlowNET.UnitTest { public class CApiTest { - public void EXPECT_EQ(object expected, object actual) + protected TF_Code TF_OK = TF_Code.TF_OK; + protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; + + protected void EXPECT_EQ(object expected, object actual) { Assert.AreEqual(expected, actual); } - public void ASSERT_EQ(object expected, object actual) + protected void ASSERT_EQ(object expected, object actual) { Assert.AreEqual(expected, actual); } - public void ASSERT_TRUE(bool condition) + protected void ASSERT_TRUE(bool condition) { Assert.IsTrue(condition); } + + protected OperationDescription TF_NewOperation(Graph graph, string opType, string opName) + { + return c_api.TF_NewOperation(graph, opType, opName); + } + + protected void TF_AddInput(OperationDescription desc, TF_Output input) + { + c_api.TF_AddInput(desc, input); + } + + protected Operation TF_FinishOperation(OperationDescription desc, Status s) + { + return c_api.TF_FinishOperation(desc, s); + } + + protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s) + { + c_api.TF_SetAttrTensor(desc, attrName, value, s); + } + + protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype) + { + c_api.TF_SetAttrType(desc, attrName, dtype); + } + + protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value) + { + c_api.TF_SetAttrBool(desc, attrName, value); + } + + protected TF_Code TF_GetCode(Status s) + { + return s.Code; + } } }