diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 6af0f6fc..36034437 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 16 -VisualStudioVersion = 16.0.29102.190 +# Visual Studio Version 17 +VisualStudioVersion = 17.0.31423.177 MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" EndProject @@ -21,6 +21,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Native.UnitTest" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\TensorFlowNET.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -139,6 +141,18 @@ Global {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x64.Build.0 = Release|Any CPU {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x86.ActiveCfg = Release|Any CPU {79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x86.Build.0 = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x64.ActiveCfg = Debug|x64 + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x64.Build.0 = Debug|x64 + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x86.ActiveCfg = Debug|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x86.Build.0 = Debug|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|Any CPU.Build.0 = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.ActiveCfg = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU + {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/test/TensorFlowNET.UnitTest/Basics/QueueTest.cs b/test/TensorFlowNET.Graph.UnitTest/Basics/QueueTest.cs similarity index 99% rename from test/TensorFlowNET.UnitTest/Basics/QueueTest.cs rename to test/TensorFlowNET.Graph.UnitTest/Basics/QueueTest.cs index 26907212..4fa1a7da 100644 --- a/test/TensorFlowNET.UnitTest/Basics/QueueTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/Basics/QueueTest.cs @@ -1,7 +1,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System.Linq; using Tensorflow; -using Tensorflow.UnitTest; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.Basics diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/CondTestCases.cs similarity index 96% rename from test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs rename to test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/CondTestCases.cs index 4fae3de3..917280e4 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs +++ b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/CondTestCases.cs @@ -1,9 +1,8 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow; -using Tensorflow.UnitTest; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.control_flow_ops_test +namespace TensorFlowNET.UnitTest.ControlFlowTest { /// /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/ShapeTestCase.cs b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/ShapeTestCase.cs similarity index 88% rename from test/TensorFlowNET.UnitTest/control_flow_ops_test/ShapeTestCase.cs rename to test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/ShapeTestCase.cs index df2f6d6d..dc7d5af8 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/ShapeTestCase.cs +++ b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/ShapeTestCase.cs @@ -1,8 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow; -using Tensorflow.UnitTest; -namespace TensorFlowNET.UnitTest.control_flow_ops_test +namespace TensorFlowNET.UnitTest.ControlFlowTest { /// /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs similarity index 95% rename from test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs rename to test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs index 4ebe6cef..81425358 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs +++ b/test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs @@ -1,10 +1,9 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using Tensorflow; -using Tensorflow.UnitTest; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.control_flow_ops_test +namespace TensorFlowNET.UnitTest.ControlFlowTest { [TestClass] public class WhileContextTestCase : GraphModeTestBase diff --git a/test/TensorFlowNET.UnitTest/functional_ops_test/ScanTestCase.cs b/test/TensorFlowNET.Graph.UnitTest/FunctionalOpsTest/ScanTestCase.cs similarity index 75% rename from test/TensorFlowNET.UnitTest/functional_ops_test/ScanTestCase.cs rename to test/TensorFlowNET.Graph.UnitTest/FunctionalOpsTest/ScanTestCase.cs index cf871cec..6e9c707a 100644 --- a/test/TensorFlowNET.UnitTest/functional_ops_test/ScanTestCase.cs +++ b/test/TensorFlowNET.Graph.UnitTest/FunctionalOpsTest/ScanTestCase.cs @@ -2,10 +2,9 @@ using Tensorflow.NumPy; using System; using Tensorflow; -using Tensorflow.UnitTest; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.functional_ops_test +namespace TensorFlowNET.UnitTest.FunctionalOpsTest { /// /// https://www.tensorflow.org/api_docs/python/tf/scan @@ -22,7 +21,8 @@ namespace TensorFlowNET.UnitTest.functional_ops_test var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6)); var scan = functional_ops.scan(fn, input); - sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6))).Should().Be(np.array(1, 3, 6, 10, 15, 21)); + var result = sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6))); + Assert.AreEqual(result, np.array(1, 3, 6, 10, 15, 21)); } [TestMethod, Ignore("need UpdateEdge API")] @@ -34,7 +34,8 @@ namespace TensorFlowNET.UnitTest.functional_ops_test var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6)); var scan = functional_ops.scan(fn, input, reverse: true); - sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6))).Should().Be(np.array(21, 20, 18, 15, 11, 6)); + var result = sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6))); + Assert.AreEqual(result, np.array(21, 20, 18, 15, 11, 6)); } } } diff --git a/test/TensorFlowNET.UnitTest/GradientTest/GradientTest.cs b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs similarity index 99% rename from test/TensorFlowNET.UnitTest/GradientTest/GradientTest.cs rename to test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs index 143b40e3..246488a9 100644 --- a/test/TensorFlowNET.UnitTest/GradientTest/GradientTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Linq; using Tensorflow; -using Tensorflow.UnitTest; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.Gradient diff --git a/test/TensorFlowNET.UnitTest/GraphModeTestBase.cs b/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs similarity index 74% rename from test/TensorFlowNET.UnitTest/GraphModeTestBase.cs rename to test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs index 8d008ddb..bb3910b9 100644 --- a/test/TensorFlowNET.UnitTest/GraphModeTestBase.cs +++ b/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs @@ -1,9 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using TensorFlowNET.UnitTest; using static Tensorflow.Binding; -using static Tensorflow.KerasApi; -namespace Tensorflow.UnitTest +namespace TensorFlowNET.UnitTest { public class GraphModeTestBase : PythonTest { @@ -16,7 +14,6 @@ namespace Tensorflow.UnitTest [TestCleanup] public void TestClean() { - keras.backend.clear_session(); tf.enable_eager_execution(); } } diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs similarity index 90% rename from test/TensorFlowNET.UnitTest/ImageTest.cs rename to test/TensorFlowNET.Graph.UnitTest/ImageTest.cs index 02f88fe6..39a004f0 100644 --- a/test/TensorFlowNET.UnitTest/ImageTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs @@ -1,12 +1,10 @@ -using FluentAssertions; -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow.NumPy; using System.Linq; using Tensorflow; -using Tensorflow.UnitTest; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.Basics +namespace TensorFlowNET.UnitTest { /// /// Find more examples in https://www.programcreek.com/python/example/90444/tensorflow.read_file @@ -84,14 +82,14 @@ namespace TensorFlowNET.UnitTest.Basics var result = sess.run(cropped); // check if cropped to 1x1 center was succesfull - result.size.Should().Be(1); - result[0, 0, 0, 0].Should().Be(4f); + Assert.AreEqual(result.size, 1); + Assert.AreEqual(result[0, 0, 0, 0], 4f); cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2); result = sess.run(cropped); // check if flipped and no cropping occured - result.size.Should().Be(16); - result[0, 0, 0, 0].Should().Be(12f); + Assert.AreEqual(result.size, 16); + Assert.AreEqual(result[0, 0, 0, 0], 12f); } } diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs similarity index 84% rename from test/TensorFlowNET.UnitTest/MultithreadingTests.cs rename to test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs index f1c2e633..c30818e6 100644 --- a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs +++ b/test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs @@ -1,12 +1,10 @@ -using FluentAssertions; -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow.NumPy; using System; using System.IO; using System.Linq; using System.Runtime.InteropServices; using Tensorflow; -using Tensorflow.UnitTest; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest @@ -24,15 +22,15 @@ namespace TensorFlowNET.UnitTest //the core method void Core(int tid) { - tf.peak_default_graph().Should().BeNull(); + Assert.IsNull(tf.peak_default_graph()); using (var sess = tf.Session()) { var default_graph = tf.peak_default_graph(); - var sess_graph = sess.GetPrivate("_graph"); - sess_graph.Should().NotBeNull(); - default_graph.Should().NotBeNull() - .And.BeEquivalentTo(sess_graph); + var sess_graph = sess.graph; + Assert.IsNotNull(default_graph); + Assert.IsNotNull(sess_graph); + Assert.AreEqual(default_graph, sess_graph); } } } @@ -47,15 +45,15 @@ namespace TensorFlowNET.UnitTest //the core method void Core(int tid) { - tf.peak_default_graph().Should().BeNull(); + Assert.IsNull(tf.peak_default_graph()); //tf.Session created an other graph using (var sess = tf.Session()) { var default_graph = tf.peak_default_graph(); - var sess_graph = sess.GetPrivate("_graph"); - sess_graph.Should().NotBeNull(); - default_graph.Should().NotBeNull() - .And.BeEquivalentTo(sess_graph); + var sess_graph = sess.graph; + Assert.IsNotNull(default_graph); + Assert.IsNotNull(sess_graph); + Assert.AreEqual(default_graph, sess_graph); } } } @@ -70,19 +68,18 @@ namespace TensorFlowNET.UnitTest //the core method void Core(int tid) { - tf.peak_default_graph().Should().BeNull(); + Assert.IsNull(tf.peak_default_graph()); var beforehand = tf.get_default_graph(); //this should create default automatically. - beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread."); beforehand.as_default(); - tf.peak_default_graph().Should().NotBeNull(); + Assert.IsNotNull(tf.peak_default_graph()); using (var sess = tf.Session()) { var default_graph = tf.peak_default_graph(); - var sess_graph = sess.GetPrivate("_graph"); - sess_graph.Should().NotBeNull(); - default_graph.Should().NotBeNull() - .And.BeEquivalentTo(sess_graph); + var sess_graph = sess.graph; + Assert.IsNotNull(default_graph); + Assert.IsNotNull(sess_graph); + Assert.AreEqual(default_graph, sess_graph); Console.WriteLine($"{tid}-{default_graph.graph_key}"); @@ -188,7 +185,7 @@ namespace TensorFlowNET.UnitTest //the core method void Core(int tid) { - tf.peak_default_graph().Should().BeNull(); + Assert.IsNull(tf.peak_default_graph()); //graph is created automatically to perform create these operations var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); @@ -197,7 +194,8 @@ namespace TensorFlowNET.UnitTest { using (var sess = tf.Session()) { - sess.run(math).GetAtIndex(0).Should().Be(5); + var result = sess.run(math); + Assert.AreEqual(result.GetAtIndex(0), 5f); } } } @@ -213,14 +211,14 @@ namespace TensorFlowNET.UnitTest { using (var sess = tf.Session()) { - tf.peak_default_graph().Should().NotBeNull(); + Assert.IsNotNull(tf.peak_default_graph()); //graph is created automatically to perform create these operations var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); var math = a1 + a2; var result = sess.run(math); - result.GetAtIndex(0).Should().Be(5); + Assert.AreEqual(result.GetAtIndex(0), 5f); } } } @@ -235,7 +233,7 @@ namespace TensorFlowNET.UnitTest { using (var sess = tf.Session()) { - tf.peak_default_graph().Should().NotBeNull(); + Assert.IsNotNull(tf.peak_default_graph()); //graph is created automatically to perform create these operations var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); @@ -252,7 +250,7 @@ namespace TensorFlowNET.UnitTest //the core method void Core(int tid) { - tf.peak_default_graph().Should().BeNull(); + Assert.IsNull(tf.peak_default_graph()); //graph is created automatically to perform create these operations var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); @@ -268,7 +266,7 @@ namespace TensorFlowNET.UnitTest //the core method void Core(int tid) { - tf.peak_default_graph().Should().BeNull(); + Assert.IsNull(tf.peak_default_graph()); //graph is created automatically to perform create these operations var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }, name: "ConstantK"); diff --git a/test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs new file mode 100644 index 00000000..40763ece --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs @@ -0,0 +1,78 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Basics +{ + [TestClass] + public class NameScopeTest : GraphModeTestBase + { + string name = ""; + + [TestMethod] + public void NestedNameScope() + { + Graph g = tf.Graph().as_default(); + + tf_with(new ops.NameScope("scope1"), scope1 => + { + name = scope1; + Assert.AreEqual("scope1", g._name_stack); + Assert.AreEqual("scope1/", name); + + var const1 = tf.constant(1.0); + Assert.AreEqual("scope1/Const:0", const1.name); + + tf_with(new ops.NameScope("scope2"), scope2 => + { + name = scope2; + Assert.AreEqual("scope1/scope2", g._name_stack); + Assert.AreEqual("scope1/scope2/", name); + + var const2 = tf.constant(2.0); + Assert.AreEqual("scope1/scope2/Const:0", const2.name); + }); + + Assert.AreEqual("scope1", g._name_stack); + var const3 = tf.constant(2.0); + Assert.AreEqual("scope1/Const_1:0", const3.name); + }); + + g.Dispose(); + + Assert.AreEqual("", g._name_stack); + } + + [TestMethod, Ignore("Unimplemented Usage")] + public void NestedNameScope_Using() + { + Graph g = tf.Graph().as_default(); + + using (var name = new ops.NameScope("scope1")) + { + Assert.AreEqual("scope1", g._name_stack); + Assert.AreEqual("scope1/", name); + + var const1 = tf.constant(1.0); + Assert.AreEqual("scope1/Const:0", const1.name); + + using (var name2 = new ops.NameScope("scope2")) + { + Assert.AreEqual("scope1/scope2", g._name_stack); + Assert.AreEqual("scope1/scope2/", name); + + var const2 = tf.constant(2.0); + Assert.AreEqual("scope1/scope2/Const:0", const2.name); + } + + Assert.AreEqual("scope1", g._name_stack); + var const3 = tf.constant(2.0); + Assert.AreEqual("scope1/Const_1:0", const3.name); + }; + + g.Dispose(); + + Assert.AreEqual("", g._name_stack); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.Graph.UnitTest/OperationsTest.cs similarity index 99% rename from test/TensorFlowNET.UnitTest/OperationsTest.cs rename to test/TensorFlowNET.Graph.UnitTest/OperationsTest.cs index a85a2f06..df34d51d 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/OperationsTest.cs @@ -4,8 +4,6 @@ using System; using System.Collections.Generic; using System.Linq; using Tensorflow; -using Tensorflow.UnitTest; -using Tensorflow.Util; using static Tensorflow.Binding; using Buffer = Tensorflow.Buffer; diff --git a/test/TensorFlowNET.Graph.UnitTest/PythonTest.cs b/test/TensorFlowNET.Graph.UnitTest/PythonTest.cs new file mode 100644 index 00000000..8ab25ee6 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/PythonTest.cs @@ -0,0 +1,335 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Newtonsoft.Json.Linq; +using Tensorflow.NumPy; +using System; +using System.Collections; +using System.Linq; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + /// + /// Use as base class for test classes to get additional assertions + /// + public class PythonTest + { + #region python compatibility layer + protected PythonTest self { get => this; } + protected int None => -1; + #endregion + + #region pytest assertions + + public void assertItemsEqual(ICollection given, ICollection expected) + { + if (given is Hashtable && expected is Hashtable) + { + Assert.AreEqual(JObject.FromObject(expected).ToString(), JObject.FromObject(given).ToString()); + return; + } + Assert.IsNotNull(expected); + Assert.IsNotNull(given); + var e = expected.OfType().ToArray(); + var g = given.OfType().ToArray(); + Assert.AreEqual(e.Length, g.Length, $"The collections differ in length expected {e.Length} but got {g.Length}"); + for (int i = 0; i < e.Length; i++) + { + /*if (g[i] is NDArray && e[i] is NDArray) + assertItemsEqual((g[i] as NDArray).GetData(), (e[i] as NDArray).GetData()); + else*/ + if (e[i] is ICollection && g[i] is ICollection) + assertEqual(g[i], e[i]); + else + Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}"); + } + } + + public void assertAllEqual(ICollection given, ICollection expected) + { + assertItemsEqual(given, expected); + } + + public void assertFloat32Equal(float expected, float actual, string msg) + { + float eps = 1e-6f; + Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); + } + + public void assertFloat64Equal(double expected, double actual, string msg) + { + double eps = 1e-16f; + Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); + } + + public void assertEqual(object given, object expected) + { + /*if (given is NDArray && expected is NDArray) + { + assertItemsEqual((given as NDArray).GetData(), (expected as NDArray).GetData()); + return; + }*/ + if (given is Hashtable && expected is Hashtable) + { + Assert.AreEqual(JObject.FromObject(expected).ToString(), JObject.FromObject(given).ToString()); + return; + } + if (given is ICollection && expected is ICollection) + { + assertItemsEqual(given as ICollection, expected as ICollection); + return; + } + if (given is float && expected is float) + { + assertFloat32Equal((float)expected, (float)given, ""); + return; + } + if (given is double && expected is double) + { + assertFloat64Equal((double)expected, (double)given, ""); + return; + } + Assert.AreEqual(expected, given); + } + + public void assertEquals(object given, object expected) + { + assertEqual(given, expected); + } + + public void assert(object given) + { + if (given is bool) + Assert.IsTrue((bool)given); + Assert.IsNotNull(given); + } + + public void assertIsNotNone(object given) + { + Assert.IsNotNull(given); + } + + public void assertFalse(bool cond) + { + Assert.IsFalse(cond); + } + + public void assertTrue(bool cond) + { + Assert.IsTrue(cond); + } + + public void assertAllClose(NDArray array1, NDArray array2, double eps = 1e-5) + { + Assert.IsTrue(np.allclose(array1, array2, rtol: eps)); + } + + public void assertAllClose(double value, NDArray array2, double eps = 1e-5) + { + var array1 = np.ones_like(array2) * value; + Assert.IsTrue(np.allclose(array1, array2, rtol: eps)); + } + + public void assertProtoEquals(object toProto, object o) + { + throw new NotImplementedException(); + } + + #endregion + + #region tensor evaluation and test session + + //protected object _eval_helper(Tensor[] tensors) + //{ + // if (tensors == null) + // return null; + // return nest.map_structure(self._eval_tensor, tensors); + //} + + protected object _eval_tensor(object tensor) + { + if (tensor == null) + return None; + //else if (callable(tensor)) + // return self._eval_helper(tensor()) + else + { + try + { + //TODO: + // if sparse_tensor.is_sparse(tensor): + // return sparse_tensor.SparseTensorValue(tensor.indices, tensor.values, + // tensor.dense_shape) + //return (tensor as Tensor).numpy(); + } + catch (Exception) + { + throw new ValueError("Unsupported type: " + tensor.GetType()); + } + return null; + } + } + + /// + /// This function is used in many original tensorflow unit tests to evaluate tensors + /// in a test session with special settings (for instance constant folding off) + /// + /// + public T evaluate(Tensor tensor) + { + object result = null; + // if context.executing_eagerly(): + // return self._eval_helper(tensors) + // else: + { + using (var sess = tf.Session()) + { + var ndarray = tensor.eval(sess); + if (typeof(T) == typeof(double)) + { + double x = ndarray; + result = x; + } + else if (typeof(T) == typeof(int)) + { + int x = ndarray; + result = x; + } + else + { + result = ndarray; + } + } + + return (T)result; + } + } + + + public Session cached_session() + { + throw new NotImplementedException(); + } + + //Returns a TensorFlow Session for use in executing tests. + public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false) + { + //Note that this will set this session and the graph as global defaults. + + //Use the `use_gpu` and `force_gpu` options to control where ops are run.If + //`force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if + //`use_gpu` is True, TensorFlow tries to run as many ops on the GPU as + //possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to + //the CPU. + + //Example: + //```python + //class MyOperatorTest(test_util.TensorFlowTestCase): + // def testMyOperator(self): + // with self.session(use_gpu= True): + // valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] + // result = MyOperator(valid_input).eval() + // self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] + // invalid_input = [-1.0, 2.0, 7.0] + // with self.assertRaisesOpError("negative input not supported"): + // MyOperator(invalid_input).eval() + //``` + + //Args: + // graph: Optional graph to use during the returned session. + // config: An optional config_pb2.ConfigProto to use to configure the + // session. + // use_gpu: If True, attempt to run as many ops as possible on GPU. + // force_gpu: If True, pin all ops to `/device:GPU:0`. + + //Yields: + // A Session object that should be used as a context manager to surround + // the graph building and execution code in a test case. + + Session s = null; + //if (context.executing_eagerly()) + // yield None + //else + //{ + s = self._create_session(graph, config, force_gpu); + self._constrain_devices_and_set_default(s, use_gpu, force_gpu); + //} + return s.as_default(); + } + + private ITensorFlowObject _constrain_devices_and_set_default(Session sess, bool useGpu, bool forceGpu) + { + //def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): + //"""Set the session and its graph to global default and constrain devices.""" + //if context.executing_eagerly(): + // yield None + //else: + // with sess.graph.as_default(), sess.as_default(): + // if force_gpu: + // # Use the name of an actual device if one is detected, or + // # '/device:GPU:0' otherwise + // gpu_name = gpu_device_name() + // if not gpu_name: + // gpu_name = "/device:GPU:0" + // with sess.graph.device(gpu_name): + // yield sess + // elif use_gpu: + // yield sess + // else: + // with sess.graph.device("/device:CPU:0"): + // yield sess + return sess; + } + + // See session() for details. + private Session _create_session(Graph graph, object cfg, bool forceGpu) + { + var prepare_config = new Func((config) => + { + // """Returns a config for sessions. + // Args: + // config: An optional config_pb2.ConfigProto to use to configure the + // session. + // Returns: + // A config_pb2.ConfigProto object. + + //TODO: config + + // # use_gpu=False. Currently many tests rely on the fact that any device + // # will be used even when a specific device is supposed to be used. + // allow_soft_placement = not force_gpu + // if config is None: + // config = config_pb2.ConfigProto() + // config.allow_soft_placement = allow_soft_placement + // config.gpu_options.per_process_gpu_memory_fraction = 0.3 + // elif not allow_soft_placement and config.allow_soft_placement: + // config_copy = config_pb2.ConfigProto() + // config_copy.CopyFrom(config) + // config = config_copy + // config.allow_soft_placement = False + // # Don't perform optimizations for tests so we don't inadvertently run + // # gpu ops on cpu + // config.graph_options.optimizer_options.opt_level = -1 + // # Disable Grappler constant folding since some tests & benchmarks + // # use constant input and become meaningless after constant folding. + // # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE + // # GRAPPLER TEAM. + // config.graph_options.rewrite_options.constant_folding = ( + // rewriter_config_pb2.RewriterConfig.OFF) + // config.graph_options.rewrite_options.pin_to_host_optimization = ( + // rewriter_config_pb2.RewriterConfig.OFF) + return config; + }); + //TODO: use this instead of normal session + //return new ErrorLoggingSession(graph = graph, config = prepare_config(config)) + return new Session(graph);//, config = prepare_config(config)) + } + + #endregion + + public void AssetSequenceEqual(T[] a, T[] b) + { + Assert.IsTrue(Enumerable.SequenceEqual(a, b)); + } + } +} diff --git a/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj b/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj new file mode 100644 index 00000000..dc6976ad --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj @@ -0,0 +1,36 @@ + + + + net5.0 + 9.0 + false + TensorFlowNET.UnitTest + AnyCPU;x64 + + + + DEBUG;TRACE + true + + + + DEBUG;TRACE + true + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + diff --git a/test/TensorFlowNET.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs b/test/TensorFlowNET.Graph.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs similarity index 100% rename from test/TensorFlowNET.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs rename to test/TensorFlowNET.Graph.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs diff --git a/test/TensorFlowNET.Graph.UnitTest/Utilities/TestHelper.cs b/test/TensorFlowNET.Graph.UnitTest/Utilities/TestHelper.cs new file mode 100644 index 00000000..d1cda728 --- /dev/null +++ b/test/TensorFlowNET.Graph.UnitTest/Utilities/TestHelper.cs @@ -0,0 +1,22 @@ +using System; +using System.IO; + +namespace TensorFlowNET.UnitTest +{ + public class TestHelper + { + public static string GetFullPathFromDataDir(string fileName) + { + var dataDir = GetRootContentDir(Directory.GetCurrentDirectory()); + return Path.Combine(dataDir, fileName); + } + + static string GetRootContentDir(string dir) + { + var path = Path.GetFullPath(Path.Combine(dir, "data")); + if (Directory.Exists(path)) + return path; + return GetRootContentDir(Path.GetFullPath(Path.Combine(dir, ".."))); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs index 57d21a8b..9f471957 100644 --- a/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs @@ -30,8 +30,8 @@ namespace TensorFlowNET.UnitTest.Basics tf.set_random_seed(1234); var a2 = tf.random_uniform(1); var b2 = tf.random_shuffle(tf.constant(initValue)); - Assert.AreEqual(a1, a2); - Assert.AreEqual(b1, b2); + Assert.AreEqual(a1.numpy(), a2.numpy()); + Assert.AreEqual(b1.numpy(), b2.numpy()); } /// @@ -76,8 +76,8 @@ namespace TensorFlowNET.UnitTest.Basics var a2 = tf.random.normal(1); var b2 = tf.random.truncated_normal(1); - Assert.AreEqual(a1, a2); - Assert.AreEqual(b1, b2); + Assert.AreEqual(a1.numpy(), a2.numpy()); + Assert.AreEqual(b1.numpy(), b2.numpy()); } /// diff --git a/test/TensorFlowNET.UnitTest/Basics/TensorShapeTest.cs b/test/TensorFlowNET.UnitTest/Basics/TensorShapeTest.cs deleted file mode 100644 index 087f2413..00000000 --- a/test/TensorFlowNET.UnitTest/Basics/TensorShapeTest.cs +++ /dev/null @@ -1,67 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow.NumPy; -using System; -using Tensorflow; -using static Tensorflow.Binding; - -namespace TensorFlowNET.UnitTest.Basics -{ - [TestClass] - public class TensorShapeTest - { - [TestMethod] - public void Case1() - { - int a = 2; - int b = 3; - var dims = new[] { Unknown, a, b }; - new TensorShape(dims).GetPrivate("shape").Should().BeShaped(-1, 2, 3); - } - - [TestMethod] - public void Case2() - { - int a = 2; - int b = 3; - var dims = new[] { Unknown, a, b }; - //new TensorShape(new[] { dims }).GetPrivate("shape").Should().BeShaped(-1, 2, 3); - } - - [TestMethod] - public void Case3() - { - int a = 2; - int b = Unknown; - var dims = new[] { Unknown, a, b }; - //new TensorShape(new[] { dims }).GetPrivate("shape").Should().BeShaped(-1, 2, -1); - } - - [TestMethod] - public void Case4() - { - TensorShape shape = (Unknown, Unknown); - shape.GetPrivate("shape").Should().BeShaped(-1, -1); - } - - [TestMethod] - public void Case5() - { - TensorShape shape = (1, Unknown, 3); - shape.GetPrivate("shape").Should().BeShaped(1, -1, 3); - } - - [TestMethod] - public void Case6() - { - TensorShape shape = (Unknown, 1, 2, 3, Unknown); - shape.GetPrivate("shape").Should().BeShaped(-1, 1, 2, 3, -1); - } - - [TestMethod] - public void Case7() - { - TensorShape shape = new TensorShape(); - Assert.AreEqual(shape.rank, -1); - } - } -} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs b/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs index c8634542..796ace6c 100644 --- a/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs +++ b/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs @@ -34,7 +34,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var default_graph = tf.peak_default_graph(); - var sess_graph = sess.GetPrivate("_graph"); + var sess_graph = sess.graph; sess_graph.Should().NotBeNull(); default_graph.Should().NotBeNull() .And.BeEquivalentTo(sess_graph); diff --git a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs index 4e82e788..29913ce4 100644 --- a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs +++ b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using Tensorflow; using Tensorflow.NumPy; -using Tensorflow.UnitTest; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.Gradient diff --git a/test/TensorFlowNET.UnitTest/GradientTest/gradients_test.py b/test/TensorFlowNET.UnitTest/GradientTest/gradients_test.py deleted file mode 100644 index c53afef6..00000000 --- a/test/TensorFlowNET.UnitTest/GradientTest/gradients_test.py +++ /dev/null @@ -1,1104 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tensorflow.ops.gradients.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys -import warnings - -import numpy as np - -from tensorflow.python.client import session -from tensorflow.python.eager import backprop -from tensorflow.python.eager import context -from tensorflow.python.eager import function -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function as framework_function -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_ops -from tensorflow.python.framework import test_util -from tensorflow.python.framework.constant_op import constant -from tensorflow.python.layers import core as core_layers -from tensorflow.python.ops import array_grad # pylint: disable=unused-import -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import custom_gradient -from tensorflow.python.ops import data_flow_grad # pylint: disable=unused-import -from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import -from tensorflow.python.ops import functional_ops # pylint: disable=unused-import -from tensorflow.python.ops import gradients -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import list_ops -from tensorflow.python.ops import math_grad # pylint: disable=unused-import -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_grad # pylint: disable=unused-import -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_grad # pylint: disable=unused-import -from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import -from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.ops.nn_ops import bias_add -from tensorflow.python.platform import googletest - - -class GradientsTest(test_util.TensorFlowTestCase): - - def testGradients(self): - with ops.Graph().as_default(): - inp = constant(1.0, shape=[32, 100], name="in") - w = constant(1.0, shape=[100, 10], name="w") - b = constant(1.0, shape=[10], name="b") - xw = math_ops.matmul(inp, w, name="xw") - h = bias_add(xw, b, name="h") - w_grad = gradients.gradients(h, w)[0] - self.assertEquals("MatMul", w_grad.op.type) - self.assertEquals(w_grad.op._original_op, xw.op) - self.assertTrue(w_grad.op.get_attr("transpose_a")) - self.assertFalse(w_grad.op.get_attr("transpose_b")) - - def testUnusedOutput(self): - with ops.Graph().as_default(): - w = constant(1.0, shape=[2, 2]) - x = constant(1.0, shape=[2, 2]) - wx = math_ops.matmul(w, x) - split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0) - c = math_ops.reduce_sum(split_wx[1]) - gw = gradients.gradients(c, [w])[0] - self.assertEquals("MatMul", gw.op.type) - - def testColocateGradients(self): - with ops.Graph().as_default() as g: - w = constant(1.0, shape=[1, 1]) - x = constant(1.0, shape=[1, 2]) - with g.device("/device:GPU:0"): - wx = math_ops.matmul(w, x) - gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0] - self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups()) - - def testColocateGradientsWithAggregation(self): - with ops.Graph().as_default() as g: - with g.device("/device:GPU:1"): - w = constant(1.0, shape=[1, 1]) - x = constant(1.0, shape=[1, 2]) - y = constant(1.0, shape=[1, 2]) - wx = math_ops.matmul(w, x) - wy = math_ops.matmul(w, y) - with g.device("/device:GPU:0"): - z = wx + wy - - gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0] - self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups()) - - gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0] - self.assertTrue(wx.op.colocation_groups() != gw2.op.colocation_groups()) - - def testColocateGradientsWithAggregationInMultipleDevices(self): - with ops.Graph().as_default() as g: - with g.device("/device:GPU:1"): - w = constant(1.0, shape=[1, 1]) - x = constant(1.0, shape=[1, 2]) - y = constant(1.0, shape=[1, 2]) - with g.device("/task:1"): - wx = math_ops.matmul(w, x) - with g.device("/task:2"): - wy = math_ops.matmul(w, y) - with g.device("/device:GPU:0"): - z = wx + wy - - gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0] - self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups()) - - gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0] - self.assertTrue(w.op.colocation_groups() != gw2.op.colocation_groups()) - - def testColocateGradientsWithGateGradients(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - with ops.Graph().as_default() as g: - with g.device("/device:CPU:0"): - x = constant(1.0, shape=[1, 1]) - y = constant(1.0, shape=[1, 1]) - s = x + y - with g.device("/device:GPU:0"): - z = math_ops.reduce_sum(s) - - gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True, - gate_gradients=True)[0] - with session.Session(): - # Make sure the placer doesn't complain. - self.evaluate(gz_x) - - def testBoundaryStop(self): - # Test that we don't differentiate 'x'. The gradient function for 'x' is - # set explicitly to None so we will get an exception if the gradient code - # tries to differentiate 'x'. - with ops.Graph().as_default(): - c = constant(1.0) - x = array_ops.identity(c) - y = x + 1.0 - z = y + 1 - grads = gradients.gradients(z, [x]) - self.assertTrue(all(x is not None for x in grads)) - - @test_util.run_v1_only("b/120545219") - def testBoundaryContinue(self): - # Test that we differentiate both 'x' and 'y' correctly when x is a - # predecessor of y. - with self.cached_session(): - x = constant(1.0) - y = x * 2.0 - z = y * 3.0 - grads = gradients.gradients(z, [x, y]) - self.assertTrue(all(x is not None for x in grads)) - self.assertEqual(6.0, grads[0].eval()) - - @test_util.run_v1_only("b/120545219") - def testAggregationMethodAccumulateN(self): - with self.cached_session(): - x = constant(1.0) - y = x * 2.0 - z = y + y + y + y + y + y + y + y + y + y - grads = gradients.gradients( - z, [x, y], - aggregation_method=gradients.AggregationMethod. - EXPERIMENTAL_ACCUMULATE_N) - self.assertTrue(all(x is not None for x in grads)) - self.assertEqual(20.0, grads[0].eval()) - self.assertEqual(10.0, grads[1].eval()) - - @test_util.run_v1_only("b/120545219") - def testAggregationMethodAddN(self): - with self.cached_session(): - x = constant(1.0) - y = x * 2.0 - z = y + y + y + y + y + y + y + y + y + y - grads = gradients.gradients( - z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N) - self.assertTrue(all(x is not None for x in grads)) - self.assertEqual(20.0, grads[0].eval()) - self.assertEqual(10.0, grads[1].eval()) - - @test_util.run_v1_only("b/120545219") - def testAggregationMethodTree(self): - with self.cached_session(): - x = constant(1.0) - y = x * 2.0 - z = y + y + y + y + y + y + y + y + y + y - grads = gradients.gradients( - z, [x, y], - aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE) - self.assertTrue(all(x is not None for x in grads)) - self.assertEqual(20.0, grads[0].eval()) - self.assertEqual(10.0, grads[1].eval()) - - def testNoGradientForStringOutputs(self): - with ops.Graph().as_default(): - - def _TestOpGrad(_, float_grad, string_grad): - """Gradient function for TestStringOutput.""" - self.assertEquals(float_grad.dtype, dtypes.float32) - self.assertFalse(string_grad) - return float_grad - - ops.RegisterGradient("TestStringOutput")(_TestOpGrad) - - c = constant(1.0) - x, _ = test_ops.test_string_output(c) - z = x * 2.0 - w = z * 3.0 - grads = gradients.gradients(z, [c]) - self.assertTrue(isinstance(grads[0], ops.Tensor)) - grads = gradients.gradients(w, [c]) - self.assertTrue(isinstance(grads[0], ops.Tensor)) - - def testSingletonIndexedSlices(self): - with ops.Graph().as_default(): - x = array_ops.placeholder(dtypes.float32) - y = array_ops.identity(x) - dy = ops.IndexedSlices( - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.int32)) - dx, = gradients.gradients(y, x, grad_ys=dy) - # The IndexedSlices gradient of tf.identity is the identity map. - with self.cached_session() as sess: - vdx, vdy = sess.run( - [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]}) - self.assertEqual(vdx, vdy) - - @test_util.run_v1_only("b/120545219") - def testNonDifferentiableSwitchInWhileLoop(self): - with ops.Graph().as_default(): - v = array_ops.placeholder(dtypes.float32, []) - - def _Step(i, a, ta): - a += math_ops.cast(v, dtypes.int32) - return (i + 1, a, ta.write(i, a)) - - n = 4 - i, _, ta = control_flow_ops.while_loop( - lambda i, *_: i < n, - _Step, [0, 0, tensor_array_ops.TensorArray( - dtypes.int32, size=n)]) - target = ta.read(i - 1) - grad, = gradients.gradients(target, v) - self.assertIsNone(grad) - - def testVariableReadValueGradient(self): - with ops.Graph().as_default(): - init = constant_op.constant(100.0) - var = variables.Variable(init) - gradient = gradients.gradients(var.read_value(), var) - self.assertIsNotNone(gradient) - - def testVariableAsGraphElementGradient(self): - with ops.Graph().as_default() as graph: - init = constant_op.constant(100.0) - var = variables.Variable(init) - gradient = gradients.gradients(graph.as_graph_element(var), var) - self.assertIsNotNone(gradient) - - @test_util.run_v1_only("b/120545219") - def testVariableRefGradient(self): - with ops.Graph().as_default(): - init = constant_op.constant(100.0) - var = variables.VariableV1(init) - gradient = gradients.gradients(var._ref(), var) - self.assertIsNotNone(gradient) - - @test_util.run_v1_only("b/120545219") - def testDependentYs(self): - with self.cached_session(): - x = constant_op.constant(3.0) - y = math_ops.square(x) - y1 = math_ops.square(y) - y2 = math_ops.square(y1) - g = gradients.gradients([y, y2], x) - self.assertAllClose(17502.0, g[0].eval()) - g = gradients.gradients(y + y2, x) - self.assertAllClose(17502.0, g[0].eval()) - z = array_ops.identity(y) - z2 = array_ops.identity(y2) - g = gradients.gradients([z, z2], x) - self.assertAllClose(17502.0, g[0].eval()) - - @test_util.run_v1_only("b/120545219") - def testPartialDerivatives(self): - with self.cached_session(): - x = constant_op.constant(1.) - y = 2 * x - z = x + y - totalg = gradients.gradients(z, [x, y]) - self.assertEqual([3.0, 1.0], [g.eval() for g in totalg]) - partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y]) - self.assertEqual([1.0, 1.0], [g.eval() for g in partialg]) - - @test_util.run_v1_only("b/120545219") - def testStopGradients(self): - def _MakeGraph(rng, stop_gradients=()): - def _FunctionOf(xs, k=3): - return ops.convert_to_tensor( - sum(math_ops.matmul(rng.rand(k, k), x) for x in xs) - + rng.rand(k, k)) - - a = _FunctionOf([]) - if "a" in stop_gradients: a = array_ops.stop_gradient(a) - b = _FunctionOf([a]) - if "b" in stop_gradients: b = array_ops.stop_gradient(b) - c = _FunctionOf([a, b]) - if "c" in stop_gradients: c = array_ops.stop_gradient(c) - d = _FunctionOf([b, c]) - if "d" in stop_gradients: d = array_ops.stop_gradient(d) - return dict(a=a, b=b, c=c, d=d) - - def _Gradients(ys, xs, **kwargs): - dydxs = gradients.gradients(ys, xs, **kwargs) - dydxs = [0. * x if dydx is None else dydx - for x, dydx in zip(xs, dydxs)] - return dydxs - - seed = np.random.randint(1000) - cases = [] - subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split() - graph = _MakeGraph(np.random.RandomState(seed)) - for constants in subsets: - graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants) - for variables_ in subsets: - # compute the gradient when stopped using tf.stop_gradients - grad1 = _Gradients([graph_with_stops["d"]], - [graph_with_stops[v] for v in variables_]) - # compute the gradient when stopped using the stop_gradients kwarg - grad2 = _Gradients([graph["d"]], - [graph[v] for v in variables_], - stop_gradients=[graph[v] for v in constants]) - cases.append(dict(grad1=grad1, grad2=grad2, - constants=constants, variables=variables_)) - - # evaluate all tensors in one call to session.run for speed - with self.cached_session() as sess: - results = sess.run([(case["grad1"], case["grad2"]) for case in cases]) - - for (npgrad1, npgrad2), case in zip(results, cases): - for a, b in zip(npgrad1, npgrad2): - np.testing.assert_allclose(a, b) - - def testUnconnectedGradientsNoneUnconnectedGradients(self): - with ops.Graph().as_default(): - x = constant(1.0, shape=[2, 2]) - y = constant(3.0, shape=[3, 1]) - grad = gradients.gradients( - [y], [x], unconnected_gradients="none") - self.assertIsNone(grad[0]) - - def testUnconnectedGradientsZerosUnconnectedGradients(self): - with ops.Graph().as_default(): - x = constant(1.0, shape=[2, 2]) - y = constant(3.0, shape=[3, 1]) - grads = gradients.gradients( - [y], [x], unconnected_gradients="zero") - with self.cached_session() as sess: - self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0]) - - def testUnconnectedGradientsZeroConnectedGradients(self): - with ops.Graph().as_default(): - x = constant(1.0) - y = x * 3.0 - grad = gradients.gradients( - [y], [x], unconnected_gradients="zero") - with self.cached_session() as sess: - self.assertEquals(3.0, self.evaluate(grad)[0]) - - def testUnknownUnconnectedGradientsValueGiven(self): - with ops.Graph().as_default(): - x = constant(1.0) - y = constant(1.0) - with self.assertRaisesRegexp( - ValueError, "Unknown value for unconnected_gradients: 'nonsense'"): - gradients.gradients([y], [x], unconnected_gradients="nonsense") - - -class FunctionGradientsTest(test_util.TensorFlowTestCase): - - @classmethod - def XSquarePlusB(cls, x, b): - return x * x + b - - @classmethod - def XSquarePlusBGradient(cls, x, b, g): - # Perturb gradients (multiply by 2), so we can test that this was called. - g *= 2.0 - return g * 2.0 * x, g - - @classmethod - def _PythonGradient(cls, op, grad): - # Perturb gradients (multiply by 3), so we can test that this was called. - grad *= 3.0 - return grad * op.inputs[0] * 2.0, grad - - @classmethod - def _GetFunc(cls, **kwargs): - return framework_function.Defun(dtypes.float32, dtypes.float32, ** - kwargs)(cls.XSquarePlusB) - - def _GetFuncGradients(self, f, x_value, b_value): - x = constant_op.constant(x_value, name="x") - b = constant_op.constant(b_value, name="b") - - y = f(x, b) - grads = gradients.gradients(y, [x, b]) - with self.cached_session() as sess: - return sess.run(grads) - - def testFunctionGradientsBasic(self): - g = ops.Graph() - with g.as_default(): - f = self._GetFunc() - # Get gradients (should add SymbolicGradient node for function). - grads = self._GetFuncGradients(f, [2.0], [1.0]) - self.assertAllEqual([4.0], grads[0]) - self.assertAllEqual([1.0], grads[1]) - - def testFunctionGradientsComposition(self): - with ops.Graph().as_default(): - f = self._GetFunc() - x = constant_op.constant([2.0], name="x") - b1 = constant_op.constant([1.0], name="b1") - b2 = constant_op.constant([1.0], name="b2") - - y = f(f(x, b1), b2) - # Build gradient graph (should add SymbolicGradient node for function). - grads = gradients.gradients(y, [x, b1]) - - with self.cached_session() as sess: - self.assertAllEqual([40.0], self.evaluate(grads)[0]) - self.assertAllEqual([10.0], self.evaluate(grads)[1]) - - def testFunctionGradientsWithGradFunc(self): - g = ops.Graph() - with g.as_default(): - grad_func = framework_function.Defun(dtypes.float32, dtypes.float32, - dtypes.float32)( - self.XSquarePlusBGradient) - f = self._GetFunc(grad_func=grad_func) - # Get gradients (should add SymbolicGradient node for function, which - # uses the grad_func above, which multiplies all gradients by 2). - grads = self._GetFuncGradients(f, [2.0], [1.0]) - self.assertAllEqual([4.0 * 2], grads[0]) - self.assertAllEqual([1.0 * 2], grads[1]) - - def testFunctionGradientWithRegistration(self): - g = ops.Graph() - with g.as_default(): - f = self._GetFunc(python_grad_func=self._PythonGradient) - # Get gradients, using the python gradient function. It multiplies the - # gradients by 3. - grads = self._GetFuncGradients(f, [2.0], [1.0]) - self.assertAllEqual([4.0 * 3], grads[0]) - self.assertAllEqual([1.0 * 3], grads[1]) - - def testFunctionGradientWithGradFuncAndRegistration(self): - g = ops.Graph() - with g.as_default(): - grad_func = framework_function.Defun(dtypes.float32, dtypes.float32, - dtypes.float32)( - self.XSquarePlusBGradient) - with self.assertRaisesRegexp(ValueError, "Gradient defined twice"): - f = self._GetFunc( - grad_func=grad_func, python_grad_func=self._PythonGradient) - f.add_to_graph(ops.Graph()) - - def testGradientWrtCaptured(self): - with ops.Graph().as_default(): - x = constant_op.constant(1.0, name="x") - - @function.defun() - def Foo(): - y = math_ops.multiply(x, 2.0, name="y") - g = gradients_impl.gradients(y, x) - return g[0] - - f = Foo() - with self.cached_session() as sess: - self.assertEqual(self.evaluate(f), 2.0) - - def testGradientOfCaptured(self): - with ops.Graph().as_default(): - x = constant_op.constant(1.0, name="x") - y = math_ops.multiply(x, 2.0, name="y") - - @framework_function.Defun() - def Foo(): - g = gradients_impl.gradients(y, x) - return g[0] - - f = Foo() - with self.cached_session() as sess: - self.assertEqual(self.evaluate(f), 2.0) - - def testCapturedResourceVariable(self): - with ops.Graph().as_default(): - var = resource_variable_ops.ResourceVariable(1.0, name="var") - - @function.defun() - def Foo(): - y = math_ops.multiply(var, 2.0, name="y") - g = gradients_impl.gradients(y, var) - return g[0] - - f = Foo() - with self.cached_session() as sess: - self.evaluate(variables.global_variables_initializer()) - self.assertEqual(self.evaluate(f), 2.0) - - def testCapturedNested(self): - with ops.Graph().as_default(): - x1 = constant_op.constant(1.0, name="x1") - x2 = constant_op.constant(2.0, name="x2") - x3 = math_ops.multiply(x1, x2, name="x3") - - @function.defun() - def Outer(): - outer1 = array_ops.identity(x1, name="outer1") - - @function.defun() - def Inner(): - inner1 = array_ops.identity(outer1, name="inner1") - inner2 = array_ops.identity(x2, name="inner2") - inner3 = array_ops.identity(x3, name="inner3") - return gradients_impl.gradients([inner1, inner2, inner3, x1], - [x1, x2]) - - return Inner() - - x1_grad, x2_grad = Outer() - with self.cached_session() as sess: - # 1.0 + None + 2.0 + 1.0 = 4.0 - self.assertEqual(self.evaluate(x1_grad), 4.0) - # None + 1.0 + 1.0 + None = 2.0 - self.assertEqual(self.evaluate(x2_grad), 2.0) - - def testCapturedFromFunction(self): - with ops.Graph().as_default(): - x = constant_op.constant(1.0, name="x") - - @function.defun() - def Outer(): - y = math_ops.multiply(x, 2.0, name="y") - - @function.defun() - def Inner(): - z = math_ops.multiply(y, 3.0, name="z") - g = gradients_impl.gradients(z, y) - return g[0] - - return Inner() - - z_grad = Outer() - with self.cached_session() as sess: - self.assertEqual(self.evaluate(z_grad), 3.0) - - def testCapturedEagerTensors(self): - # Test that we can handle captured eager tensors unrelated to the gradient - # computation (i.e. we need to ignore them). - # TODO(skyewm): make it an error if you try to take the gradient wrt a - # captured EagerTensor - with context.eager_mode(): - c = constant_op.constant(2.0, name="c") - - @function.defun - def Foo(): - x = constant_op.constant(10.0, name="x") - y = math_ops.multiply(x, c, name="y") - z = math_ops.multiply(y, 3.0, name="z") - g = gradients_impl.gradients(z, x) - return g[0] - - self.assertEqual(Foo().numpy(), 6.0) - - -class StopGradientTest(test_util.TensorFlowTestCase): - - def testStopGradient(self): - with ops.Graph().as_default(): - inp = constant(1.0, shape=[100, 32], name="in") - out = array_ops.stop_gradient(inp) - igrad = gradients.gradients(out, inp)[0] - assert igrad is None - - -class PreventGradientTest(test_util.TensorFlowTestCase): - - def testPreventGradient(self): - with ops.Graph().as_default(): - inp = constant(1.0, shape=[100, 32], name="in") - out = array_ops.prevent_gradient(inp) - with self.assertRaisesRegexp(LookupError, "explicitly disabled"): - _ = gradients.gradients(out, inp) - - -class HessianVectorProductTest(test_util.TensorFlowTestCase): - - @test_util.run_v1_only("b/120545219") - def testHessianVectorProduct(self): - # Manually compute the Hessian explicitly for a low-dimensional problem - # and check that HessianVectorProduct matches multiplication by the - # explicit Hessian. - # Specifically, the Hessian of f(x) = x^T A x is - # H = A + A^T. - # We expect HessianVectorProduct(f(x), x, v) to be H v. - m = 4 - rng = np.random.RandomState([1, 2, 3]) - mat_value = rng.randn(m, m).astype("float32") - v_value = rng.randn(m, 1).astype("float32") - x_value = rng.randn(m, 1).astype("float32") - hess_value = mat_value + mat_value.T - hess_v_value = np.dot(hess_value, v_value) - for use_gpu in [False, True]: - with self.cached_session(use_gpu=use_gpu): - mat = constant_op.constant(mat_value) - v = constant_op.constant(v_value) - x = constant_op.constant(x_value) - mat_x = math_ops.matmul(mat, x, name="Ax") - x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx") - hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0] - hess_v_actual = self.evaluate(hess_v) - self.assertAllClose(hess_v_value, hess_v_actual) - - -class HessianTest(test_util.TensorFlowTestCase): - - @test_util.run_v1_only("b/120545219") - def testHessian1D(self): - # Manually compute the Hessian explicitly for a low-dimensional problem - # and check that `hessian` matches. Specifically, the Hessian of - # f(x) = x^T A x is H = A + A^T. - m = 4 - rng = np.random.RandomState([1, 2, 3]) - mat_value = rng.randn(m, m).astype("float32") - x_value = rng.randn(m).astype("float32") - hess_value = mat_value + mat_value.T - with self.session(use_gpu=True): - mat = constant_op.constant(mat_value) - x = constant_op.constant(x_value) - x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :]) - hess = gradients.hessians(x_mat_x, x)[0] - hess_actual = self.evaluate(hess) - self.assertAllClose(hess_value, hess_actual) - - @test_util.run_v1_only("b/120545219") - def testHessian1D_multi(self): - # Test the computation of the hessian with respect to multiple tensors - m = 4 - n = 3 - rng = np.random.RandomState([1, 2, 3]) - mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)] - x_values = [rng.randn(m).astype("float32") for _ in range(n)] - hess_values = [mat_value + mat_value.T for mat_value in mat_values] - with self.session(use_gpu=True): - mats = [constant_op.constant(mat_value) for mat_value in mat_values] - xs = [constant_op.constant(x_value) for x_value in x_values] - xs_mats_xs = [ - math_ops.reduce_sum(x[:, None] * mat * x[None, :]) - for x, mat in zip(xs, mats) - ] - hessians = gradients.hessians(xs_mats_xs, xs) - hessians_actual = [hess.eval() for hess in hessians] - for hess_value, hess_actual in zip(hess_values, hessians_actual): - self.assertAllClose(hess_value, hess_actual) - - @test_util.run_v1_only("b/120545219") - def testHessianInvalidDimension(self): - for shape in [(10, 10), None]: - with self.cached_session(use_gpu=True): - x = array_ops.placeholder(dtypes.float32, shape) - # Expect a ValueError because the dimensions are wrong - with self.assertRaises(ValueError): - gradients.hessians(x, x) - - @test_util.run_v1_only("b/120545219") - def testHessian2D_square_matrix(self): - # Manually compute the Hessian explicitly for a low-dimensional problem - # and check that `hessian` matches. Specifically, the Hessian of - # f(x) = 1/2 * x^T * x is H = constant (block identity matrix) - m = 3 - rng = np.random.RandomState([1, 2, 3]) - x_value = rng.randn(m, m).astype("float32") - with self.session(use_gpu=True): - x = constant_op.constant(x_value) - x_square = math_ops.reduce_sum( - math_ops.matmul(array_ops.transpose(x), x) * 0.5 - ) - hess = gradients.hessians(x_square, x)[0] - hess_actual = self.evaluate(hess) - hess_value = np.bmat([ - [elem*np.ones((m, m)) for elem in vec] - for vec in np.eye(m) - ]).astype("float32") - self.assertAllEqual((m, m, m, m), hess_actual.shape) - self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m))) - - @test_util.run_v1_only("b/120545219") - def testHessian2D_non_square_matrix(self): - m = 3 - n = 4 - rng = np.random.RandomState([1, 2, 3]) - x_value = rng.randn(m, n).astype("float32") - with self.session(use_gpu=True): - x = constant_op.constant(x_value) - x_square = math_ops.reduce_sum( - math_ops.matmul(array_ops.transpose(x), x) * 0.5 - ) - hess = gradients.hessians(x_square, x)[0] - hess_actual = self.evaluate(hess) - hess_value = np.bmat([ - [elem*np.ones((n, n)) for elem in vec] - for vec in np.eye(m) - ]).astype("float32") - self.assertAllEqual((m, n, m, n), hess_actual.shape) - self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n))) - - -class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): - - @test_util.run_v1_only("b/120545219") - def testIndexedSlicesToTensor(self): - with self.cached_session(): - np_val = np.random.rand(4, 4, 4, 4).astype(np.float32) - c = constant_op.constant(np_val) - c_sparse = math_ops._as_indexed_slices(c) - self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval()) - c_dense = math_ops.multiply(c_sparse, 1.0) - self.assertAllClose(np_val, self.evaluate(c_dense)) - - @test_util.run_v1_only("b/120545219") - def testIndexedSlicesToTensorList(self): - with self.cached_session(): - numpy_list = [] - dense_list = [] - sparse_list = [] - for _ in range(3): - np_val = np.random.rand(4, 4, 4, 4).astype(np.float32) - c = constant_op.constant(np_val) - c_sparse = math_ops._as_indexed_slices(c) - numpy_list.append(np_val) - dense_list.append(c) - sparse_list.append(c_sparse) - packed_dense = array_ops.stack(dense_list) - packed_sparse = array_ops.stack(sparse_list) - self.assertAllClose(packed_dense.eval(), self.evaluate(packed_sparse)) - - @test_util.run_v1_only("b/120545219") - def testInt64Indices(self): - with self.cached_session(): - np_val = np.random.rand(4, 4, 4, 4).astype(np.float32) - c = constant_op.constant(np_val) - c_sparse = math_ops._as_indexed_slices(c) - c_sparse = ops.IndexedSlices( - c_sparse.values, - math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape) - self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval()) - c_dense = math_ops.multiply(c_sparse, 1.0) - self.assertAllClose(np_val, self.evaluate(c_dense)) - - @test_util.run_v1_only("b/120545219") - def testWarnings(self): - # TODO(gunan) Reenable after this issue is fixed: - # https://github.com/google/protobuf/issues/2812 - if sys.version_info >= (3, 5): - self.skipTest("Skipped test for Python 3.5+") - - # Smaller than the threshold: no warning. - c_sparse = ops.IndexedSlices( - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4])) - with warnings.catch_warnings(record=True) as w: - math_ops.multiply(c_sparse, 1.0) - self.assertEqual(0, len(w)) - - # Greater than or equal to the threshold: warning. - c_sparse = ops.IndexedSlices( - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100])) - # "always" filter prevents the warning from being suppressed if it was - # already triggered in a different test. - warnings.simplefilter("always") - with warnings.catch_warnings(record=True) as w: - math_ops.multiply(c_sparse, 1.0) - self.assertEqual(1, len(w)) - self.assertTrue( - "with 100000000 elements. This may consume a large amount of memory." in - str(w[0].message)) - - # Unknown dense shape: warning. - c_sparse = ops.IndexedSlices( - array_ops.placeholder(dtypes.float32), - array_ops.placeholder(dtypes.int32), - array_ops.placeholder(dtypes.int32)) - with warnings.catch_warnings(record=True) as w: - math_ops.multiply(c_sparse, 1.0) - self.assertEqual(1, len(w)) - self.assertTrue( - "of unknown shape. This may consume a large amount of memory." in - str(w[0].message)) - - -class OnlyRealGradientsTest(test_util.TensorFlowTestCase): - - @test_util.run_v1_only("b/120545219") - def testRealOnly(self): - x = constant_op.constant(7+3j, dtype=dtypes.complex64) - y = math_ops.square(x) - with self.assertRaisesRegexp( - TypeError, - r"Gradients of complex tensors must set grad_ys " - r"\(y\.dtype = tf\.complex64\)"): - gradients.gradients(y, x) - - -class ResourceCondTest(test_util.TensorFlowTestCase): - - @test_util.run_v1_only("b/120545219") - def testBasic(self): - gamma = resource_variable_ops.ResourceVariable( - np.random.random((3,)), - dtype="float32", name="gamma") - - inputs = array_ops.ones(shape=(3,), dtype="float32") - - def TestFn(): - output = inputs + gamma - return output - - training = array_ops.placeholder_with_default(True, shape=()) - output = control_flow_ops.cond( - training, TestFn, lambda: inputs) - - loss = output - - grads = gradients.gradients( - loss, [gamma]) - self.assertTrue(None not in grads) - - -class CustomGradientTest(test_util.TensorFlowTestCase): - - def testCustomGradientTrivial(self): - - @custom_gradient.custom_gradient - def MyIdentity(x): - - def Grad(dy): - return [3 * dy] - - return x, Grad - - with ops.Graph().as_default(): - x = constant(3.) - y = MyIdentity(MyIdentity(x)) - dy = gradients.gradients(y, x)[0] - with session.Session(): - self.assertEqual(9., self.evaluate(dy)) - - def testCustomGradient(self): - - @custom_gradient.custom_gradient - def MyMultiply(x1, x2): - result = x1 * x2 - - def Grad(dy): - # Switched the ordering here. - return [dy * x1, dy * x2] - - return result, Grad - - with ops.Graph().as_default(): - x1 = constant(3.) - x2 = constant(5.) - y = MyMultiply(x1, x2) - dy = gradients.gradients(y, [x1, x2]) - with session.Session() as sess: - self.assertAllEqual([3., 5.], self.evaluate(dy)) - - def testCustomGradientErrors(self): - - @custom_gradient.custom_gradient - def F(x): - - def Grad(_): - raise RuntimeError("x") - - return x, Grad - - with ops.Graph().as_default(): - x = constant(1.0) - y = F(x) - with self.assertRaises(RuntimeError): - gradients.gradients(y, x) - - def testCustomGradientWithVariables(self): - - @custom_gradient.custom_gradient - def F(x): - out = core_layers.dense(x, 3, use_bias=False) - - def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name - self.assertEqual(1, len(variables)) - grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad) - return grads[0], [array_ops.ones((4, 3))] - - return out, Grad - - with ops.Graph().as_default(): - x = array_ops.ones((2, 4)) - with variable_scope.variable_scope("f", use_resource=True) as vs: - y = F(x) - all_vars = vs.global_variables() - assert len(all_vars) == 1 - grads = gradients.gradients(y, [x, all_vars[0]]) - for g in grads: - self.assertTrue(g is not None) - with session.Session() as sess: - self.evaluate(variables.global_variables_initializer()) - dw = sess.run(math_ops.reduce_sum(grads[1])) - self.assertEqual(12., dw) - - def testCustomGradientWithVariablesEager(self): - with context.eager_mode(): - layer = core_layers.Dense(4, use_bias=False) - - @custom_gradient.custom_gradient - def F(x): - out = layer(x) - - def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name - del out_grad - self.assertEqual(1, len(variables)) - return (array_ops.ones((3, 2)), - [array_ops.ones((2, 4))]) - - return out, Grad - - x = array_ops.ones((3, 2)) + 2. - with backprop.GradientTape() as tape: - tape.watch(x) - y = F(x) - w, = layer.variables - dx, dw = tape.gradient(y, [x, w]) - self.assertEqual(6., math_ops.reduce_sum(dx).numpy()) - self.assertEqual(8., math_ops.reduce_sum(dw).numpy()) - - @test_util.run_v1_only("b/120545219") - def testCustomGradientErrorsWithNonResourceVariables(self): - - def F(x, use_resource=False): - with variable_scope.variable_scope("f", use_resource=use_resource): - out = core_layers.dense(x, 4, use_bias=False) - - def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name - del out_grad - self.assertEqual(1, len(variables)) - return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))]) - - return out, Grad - - @custom_gradient.custom_gradient - def FResource(x): - return F(x, use_resource=True) - - @custom_gradient.custom_gradient - def FNonResource(x): - return F(x, use_resource=False) - - x = array_ops.ones((3, 2)) + 2. - - # Wrapping scope has use_resource=True but inner scope sets to False. Fails. - with variable_scope.variable_scope("vs1", use_resource=True): - with self.assertRaisesWithPredicateMatch(TypeError, - "must be `ResourceVariable`s"): - FNonResource(x) - - # Wrapping scope has use_resource=False but inner scope sets to True. - # Passes. - with variable_scope.variable_scope("vs2", use_resource=False): - FResource(x) - - def testWithNumpyInputs(self): - with context.eager_mode(): - - @custom_gradient.custom_gradient - def F(x): - out = x - - def Grad(_): - return (None, None) - - return out, Grad - - x = np.ones((3, 2), dtype=np.float32) - # Smoke test to ensure numpy inputs are accepted - F(x) - - @test_util.run_v1_only("b/120545219") - def testRVGradientsDynamicCond(self): - with self.cached_session(): - alpha = resource_variable_ops.ResourceVariable( - np.random.random((1,)), - dtype="float32") - - conditional = array_ops.placeholder_with_default(True, shape=()) - output = control_flow_ops.cond( - conditional, lambda: alpha * 2, lambda: alpha * 3) - - g, = gradients_impl.gradients(output, alpha) - self.evaluate(variables.global_variables_initializer()) - self.assertAllEqual(g.eval(), [2.0]) - self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0]) - - -class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase): - - def _assert_indexed_slices_equal(self, left, right): - self.assertAllEqual( - self.evaluate(ops.convert_to_tensor(left)), - self.evaluate(ops.convert_to_tensor(right))) - - def testNoGradients(self): - self.assertIsNone(gradients_impl._AggregateIndexedSlicesGradients([])) - - def testOneGradient(self): - t = math_ops._as_indexed_slices(constant_op.constant( - [[1., 2.], [0, 0], [3., 4.]])) - result = gradients_impl._AggregateIndexedSlicesGradients([t]) - self._assert_indexed_slices_equal(t, result) - - def testMultipleGradients(self): - t0 = math_ops._as_indexed_slices(constant_op.constant( - [[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices(constant_op.constant( - [[0., 0.], [5, 6], [7., 8.]])) - total = constant_op.constant( - [[1., 2.], [5, 6], [10., 12.]]) - result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1]) - self._assert_indexed_slices_equal(total, result) - - def testMultipleGradientsWithNones(self): - t0 = math_ops._as_indexed_slices(constant_op.constant( - [[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices(constant_op.constant( - [[0., 0.], [5, 6], [7., 8.]])) - t3 = None - total = constant_op.constant( - [[1., 2.], [5, 6], [10., 12.]]) - result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1, t3]) - self._assert_indexed_slices_equal(total, result) - - def testMixedTensorAndIndexedSlices(self): - t0 = math_ops._as_indexed_slices(constant_op.constant( - [[1., 2.], [0, 0], [3., 4.]])) - t1 = constant_op.constant( - [[0., 0.], [5, 6], [7., 8.]]) - total = constant_op.constant( - [[1., 2.], [5, 6], [10., 12.]]) - result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1]) - self._assert_indexed_slices_equal(total, result) - - -class TensorListGradientsTest(test_util.TensorFlowTestCase): - - def testDefaultGradYs(self): - with ops.Graph().as_default(): - tl = list_ops.empty_tensor_list( - element_dtype=dtypes.float32, - element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) - a = constant(1.0) - tl = list_ops.tensor_list_push_back(tl, a) - - grad_tl = list_ops.empty_tensor_list( - element_dtype=dtypes.float32, - element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) - grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0)) - - grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0] - with self.cached_session() as sess: - self.assertEquals(self.evaluate(grad), 5.) - - -if __name__ == "__main__": - googletest.main() diff --git a/test/TensorFlowNET.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.UnitTest/NameScopeTest.cs index 886bf056..5bf89f2c 100644 --- a/test/TensorFlowNET.UnitTest/NameScopeTest.cs +++ b/test/TensorFlowNET.UnitTest/NameScopeTest.cs @@ -1,93 +1,22 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow; -using Tensorflow.UnitTest; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.Basics { [TestClass] - public class NameScopeTest : GraphModeTestBase + public class NameScopeTest : EagerModeTestBase { string name = ""; - [TestMethod] - public void NestedNameScope() - { - Graph g = tf.Graph().as_default(); - - tf_with(new ops.NameScope("scope1"), scope1 => - { - name = scope1; - Assert.AreEqual("scope1", g._name_stack); - Assert.AreEqual("scope1/", name); - - var const1 = tf.constant(1.0); - Assert.AreEqual("scope1/Const:0", const1.name); - - tf_with(new ops.NameScope("scope2"), scope2 => - { - name = scope2; - Assert.AreEqual("scope1/scope2", g._name_stack); - Assert.AreEqual("scope1/scope2/", name); - - var const2 = tf.constant(2.0); - Assert.AreEqual("scope1/scope2/Const:0", const2.name); - }); - - Assert.AreEqual("scope1", g._name_stack); - var const3 = tf.constant(2.0); - Assert.AreEqual("scope1/Const_1:0", const3.name); - }); - - g.Dispose(); - - Assert.AreEqual("", g._name_stack); - } - [TestMethod] public void NameScopeInEagerMode() { - tf.enable_eager_execution(); - tf_with(new ops.NameScope("scope"), scope => { string name = scope; var const1 = tf.constant(1.0); }); - - tf.compat.v1.disable_eager_execution(); - } - - [TestMethod, Ignore("Unimplemented Usage")] - public void NestedNameScope_Using() - { - Graph g = tf.Graph().as_default(); - - using (var name = new ops.NameScope("scope1")) - { - Assert.AreEqual("scope1", g._name_stack); - Assert.AreEqual("scope1/", name); - - var const1 = tf.constant(1.0); - Assert.AreEqual("scope1/Const:0", const1.name); - - using (var name2 = new ops.NameScope("scope2")) - { - Assert.AreEqual("scope1/scope2", g._name_stack); - Assert.AreEqual("scope1/scope2/", name); - - var const2 = tf.constant(2.0); - Assert.AreEqual("scope1/scope2/Const:0", const2.name); - } - - Assert.AreEqual("scope1", g._name_stack); - var const3 = tf.constant(2.0); - Assert.AreEqual("scope1/Const_1:0", const3.name); - }; - - g.Dispose(); - - Assert.AreEqual("", g._name_stack); } } } diff --git a/test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs b/test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs deleted file mode 100644 index 37aa9610..00000000 --- a/test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs +++ /dev/null @@ -1,917 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -namespace Microsoft.VisualStudio.TestTools.UnitTesting -{ - using System; - //using System.Diagnostics; - //using System.Diagnostics.CodeAnalysis; - using System.Globalization; - using System.Reflection; - - /// - /// This class represents the live NON public INTERNAL object in the system - /// - internal class PrivateObject - { - #region Data - - // bind everything - private const BindingFlags BindToEveryThing = BindingFlags.Default | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Public; - -#pragma warning disable CS0414 // The field 'PrivateObject.constructorFlags' is assigned but its value is never used - private static BindingFlags constructorFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.CreateInstance | BindingFlags.NonPublic; -#pragma warning restore CS0414 // The field 'PrivateObject.constructorFlags' is assigned but its value is never used - - private object target; // automatically initialized to null - private Type originalType; // automatically initialized to null - - //private Dictionary> methodCache; // automatically initialized to null - - #endregion - - #region Constructors - - ///// - ///// Initializes a new instance of the class that contains - ///// the already existing object of the private class - ///// - ///// object that serves as starting point to reach the private members - ///// the derefrencing string using . that points to the object to be retrived as in m_X.m_Y.m_Z - //[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an object, so 'obj' seems reasonable")] - //public PrivateObject(object obj, string memberToAccess) - //{ - // Helper.CheckParameterNotNull(obj, "obj", string.Empty); - // ValidateAccessString(memberToAccess); - - // PrivateObject temp = obj as PrivateObject; - // if (temp == null) - // { - // temp = new PrivateObject(obj); - // } - - // // Split The access string - // string[] arr = memberToAccess.Split(new char[] { '.' }); - - // for (int i = 0; i < arr.Length; i++) - // { - // object next = temp.InvokeHelper(arr[i], BindToEveryThing | BindingFlags.Instance | BindingFlags.GetField | BindingFlags.GetProperty, null, CultureInfo.InvariantCulture); - // temp = new PrivateObject(next); - // } - - // this.target = temp.target; - // this.originalType = temp.originalType; - //} - - ///// - ///// Initializes a new instance of the class that wraps the - ///// specified type. - ///// - ///// Name of the assembly - ///// fully qualified name - ///// Argmenets to pass to the constructor - //public PrivateObject(string assemblyName, string typeName, params object[] args) - // : this(assemblyName, typeName, null, args) - //{ - //} - - ///// - ///// Initializes a new instance of the class that wraps the - ///// specified type. - ///// - ///// Name of the assembly - ///// fully qualified name - ///// An array of objects representing the number, order, and type of the parameters for the constructor to get - ///// Argmenets to pass to the constructor - //public PrivateObject(string assemblyName, string typeName, Type[] parameterTypes, object[] args) - // : this(Type.GetType(string.Format(CultureInfo.InvariantCulture, "{0}, {1}", typeName, assemblyName), false), parameterTypes, args) - //{ - // Helper.CheckParameterNotNull(assemblyName, "assemblyName", string.Empty); - // Helper.CheckParameterNotNull(typeName, "typeName", string.Empty); - //} - - ///// - ///// Initializes a new instance of the class that wraps the - ///// specified type. - ///// - ///// type of the object to create - ///// Argmenets to pass to the constructor - //public PrivateObject(Type type, params object[] args) - // : this(type, null, args) - //{ - // Helper.CheckParameterNotNull(type, "type", string.Empty); - //} - - ///// - ///// Initializes a new instance of the class that wraps the - ///// specified type. - ///// - ///// type of the object to create - ///// An array of objects representing the number, order, and type of the parameters for the constructor to get - ///// Argmenets to pass to the constructor - //public PrivateObject(Type type, Type[] parameterTypes, object[] args) - //{ - // Helper.CheckParameterNotNull(type, "type", string.Empty); - // object o; - // if (parameterTypes != null) - // { - // ConstructorInfo ci = type.GetConstructor(BindToEveryThing, null, parameterTypes, null); - // if (ci == null) - // { - // throw new ArgumentException(FrameworkMessages.PrivateAccessorConstructorNotFound); - // } - - // try - // { - // o = ci.Invoke(args); - // } - // catch (TargetInvocationException e) - // { - // Debug.Assert(e.InnerException != null, "Inner exception should not be null."); - // if (e.InnerException != null) - // { - // throw e.InnerException; - // } - - // throw; - // } - // } - // else - // { - // o = Activator.CreateInstance(type, constructorFlags, null, args, null); - // } - - // this.ConstructFrom(o); - //} - - /// - /// Initializes a new instance of the class that wraps - /// the given object. - /// - /// object to wrap - //[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an object, so 'obj' seems reasonable")] - public PrivateObject(object obj) - { - Helper.CheckParameterNotNull(obj, "obj", string.Empty); - this.ConstructFrom(obj); - } - - /// - /// Initializes a new instance of the class that wraps - /// the given object. - /// - /// object to wrap - /// PrivateType object - //[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an an object, so 'obj' seems reasonable")] - public PrivateObject(object obj, PrivateType type) - { - Helper.CheckParameterNotNull(type, "type", string.Empty); - this.target = obj; - this.originalType = type.ReferencedType; - } - - #endregion - - ///// - ///// Gets or sets the target - ///// - //public object Target - //{ - // get - // { - // return this.target; - // } - - // set - // { - // Helper.CheckParameterNotNull(value, "Target", string.Empty); - // this.target = value; - // this.originalType = value.GetType(); - // } - //} - - ///// - ///// Gets the type of underlying object - ///// - //public Type RealType - //{ - // get - // { - // return this.originalType; - // } - //} - - //private Dictionary> GenericMethodCache - //{ - // get - // { - // if (this.methodCache == null) - // { - // this.BuildGenericMethodCacheForType(this.originalType); - // } - - // Debug.Assert(this.methodCache != null, "Invalid method cache for type."); - - // return this.methodCache; - // } - //} - - /// - /// returns the hash code of the target object - /// - /// int representing hashcode of the target object - public override int GetHashCode() - { - //Debug.Assert(this.target != null, "target should not be null."); - return this.target.GetHashCode(); - } - - /// - /// Equals - /// - /// Object with whom to compare - /// returns true if the objects are equal. - public override bool Equals(object obj) - { - if (this != obj) - { - //Debug.Assert(this.target != null, "target should not be null."); - if (typeof(PrivateObject) == obj?.GetType()) - { - return this.target.Equals(((PrivateObject)obj).target); - } - else - { - return false; - } - } - - return true; - } - - ///// - ///// Invokes the specified method - ///// - ///// Name of the method - ///// Arguments to pass to the member to invoke. - ///// Result of method call - //public object Invoke(string name, params object[] args) - //{ - // Helper.CheckParameterNotNull(name, "name", string.Empty); - // return this.Invoke(name, null, args, CultureInfo.InvariantCulture); - //} - - ///// - ///// Invokes the specified method - ///// - ///// Name of the method - ///// An array of objects representing the number, order, and type of the parameters for the method to get. - ///// Arguments to pass to the member to invoke. - ///// Result of method call - //public object Invoke(string name, Type[] parameterTypes, object[] args) - //{ - // return this.Invoke(name, parameterTypes, args, CultureInfo.InvariantCulture); - //} - - ///// - ///// Invokes the specified method - ///// - ///// Name of the method - ///// An array of objects representing the number, order, and type of the parameters for the method to get. - ///// Arguments to pass to the member to invoke. - ///// An array of types corresponding to the types of the generic arguments. - ///// Result of method call - //public object Invoke(string name, Type[] parameterTypes, object[] args, Type[] typeArguments) - //{ - // return this.Invoke(name, BindToEveryThing, parameterTypes, args, CultureInfo.InvariantCulture, typeArguments); - //} - - ///// - ///// Invokes the specified method - ///// - ///// Name of the method - ///// Arguments to pass to the member to invoke. - ///// Culture info - ///// Result of method call - //public object Invoke(string name, object[] args, CultureInfo culture) - //{ - // return this.Invoke(name, null, args, culture); - //} - - ///// - ///// Invokes the specified method - ///// - ///// Name of the method - ///// An array of objects representing the number, order, and type of the parameters for the method to get. - ///// Arguments to pass to the member to invoke. - ///// Culture info - ///// Result of method call - //public object Invoke(string name, Type[] parameterTypes, object[] args, CultureInfo culture) - //{ - // return this.Invoke(name, BindToEveryThing, parameterTypes, args, culture); - //} - - ///// - ///// Invokes the specified method - ///// - ///// Name of the method - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// Arguments to pass to the member to invoke. - ///// Result of method call - //public object Invoke(string name, BindingFlags bindingFlags, params object[] args) - //{ - // return this.Invoke(name, bindingFlags, null, args, CultureInfo.InvariantCulture); - //} - - ///// - ///// Invokes the specified method - ///// - ///// Name of the method - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// An array of objects representing the number, order, and type of the parameters for the method to get. - ///// Arguments to pass to the member to invoke. - ///// Result of method call - //public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) - //{ - // return this.Invoke(name, bindingFlags, parameterTypes, args, CultureInfo.InvariantCulture); - //} - - ///// - ///// Invokes the specified method - ///// - ///// Name of the method - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// Arguments to pass to the member to invoke. - ///// Culture info - ///// Result of method call - //public object Invoke(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) - //{ - // return this.Invoke(name, bindingFlags, null, args, culture); - //} - - ///// - ///// Invokes the specified method - ///// - ///// Name of the method - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// An array of objects representing the number, order, and type of the parameters for the method to get. - ///// Arguments to pass to the member to invoke. - ///// Culture info - ///// Result of method call - //public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture) - //{ - // return this.Invoke(name, bindingFlags, parameterTypes, args, culture, null); - //} - - ///// - ///// Invokes the specified method - ///// - ///// Name of the method - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// An array of objects representing the number, order, and type of the parameters for the method to get. - ///// Arguments to pass to the member to invoke. - ///// Culture info - ///// An array of types corresponding to the types of the generic arguments. - ///// Result of method call - //public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture, Type[] typeArguments) - //{ - // Helper.CheckParameterNotNull(name, "name", string.Empty); - // if (parameterTypes != null) - // { - // bindingFlags |= BindToEveryThing | BindingFlags.Instance; - - // // Fix up the parameter types - // MethodInfo member = this.originalType.GetMethod(name, bindingFlags, null, parameterTypes, null); - - // // If the method was not found and type arguments were provided for generic paramaters, - // // attempt to look up a generic method. - // if ((member == null) && (typeArguments != null)) - // { - // // This method may contain generic parameters...if so, the previous call to - // // GetMethod() will fail because it doesn't fully support generic parameters. - - // // Look in the method cache to see if there is a generic method - // // on the incoming type that contains the correct signature. - // member = this.GetGenericMethodFromCache(name, parameterTypes, typeArguments, bindingFlags, null); - // } - - // if (member == null) - // { - // throw new ArgumentException( - // string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); - // } - - // try - // { - // if (member.IsGenericMethodDefinition) - // { - // MethodInfo constructed = member.MakeGenericMethod(typeArguments); - // return constructed.Invoke(this.target, bindingFlags, null, args, culture); - // } - // else - // { - // return member.Invoke(this.target, bindingFlags, null, args, culture); - // } - // } - // catch (TargetInvocationException e) - // { - // Debug.Assert(e.InnerException != null, "Inner exception should not be null."); - // if (e.InnerException != null) - // { - // throw e.InnerException; - // } - - // throw; - // } - // } - // else - // { - // return this.InvokeHelper(name, bindingFlags | BindingFlags.InvokeMethod, args, culture); - // } - //} - - ///// - ///// Gets the array element using array of subsrcipts for each dimension - ///// - ///// Name of the member - ///// the indices of array - ///// An arrya of elements. - //public object GetArrayElement(string name, params int[] indices) - //{ - // Helper.CheckParameterNotNull(name, "name", string.Empty); - // return this.GetArrayElement(name, BindToEveryThing, indices); - //} - - ///// - ///// Sets the array element using array of subsrcipts for each dimension - ///// - ///// Name of the member - ///// Value to set - ///// the indices of array - //public void SetArrayElement(string name, object value, params int[] indices) - //{ - // Helper.CheckParameterNotNull(name, "name", string.Empty); - // this.SetArrayElement(name, BindToEveryThing, value, indices); - //} - - ///// - ///// Gets the array element using array of subsrcipts for each dimension - ///// - ///// Name of the member - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// the indices of array - ///// An arrya of elements. - //public object GetArrayElement(string name, BindingFlags bindingFlags, params int[] indices) - //{ - // Helper.CheckParameterNotNull(name, "name", string.Empty); - // Array arr = (Array)this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture); - // return arr.GetValue(indices); - //} - - ///// - ///// Sets the array element using array of subsrcipts for each dimension - ///// - ///// Name of the member - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// Value to set - ///// the indices of array - //public void SetArrayElement(string name, BindingFlags bindingFlags, object value, params int[] indices) - //{ - // Helper.CheckParameterNotNull(name, "name", string.Empty); - // Array arr = (Array)this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture); - // arr.SetValue(value, indices); - //} - - ///// - ///// Get the field - ///// - ///// Name of the field - ///// The field. - //public object GetField(string name) - //{ - // Helper.CheckParameterNotNull(name, "name", string.Empty); - // return this.GetField(name, BindToEveryThing); - //} - - ///// - ///// Sets the field - ///// - ///// Name of the field - ///// value to set - //public void SetField(string name, object value) - //{ - // Helper.CheckParameterNotNull(name, "name", string.Empty); - // this.SetField(name, BindToEveryThing, value); - //} - - ///// - ///// Gets the field - ///// - ///// Name of the field - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// The field. - //public object GetField(string name, BindingFlags bindingFlags) - //{ - // Helper.CheckParameterNotNull(name, "name", string.Empty); - // return this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture); - //} - - ///// - ///// Sets the field - ///// - ///// Name of the field - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// value to set - //public void SetField(string name, BindingFlags bindingFlags, object value) - //{ - // Helper.CheckParameterNotNull(name, "name", string.Empty); - // this.InvokeHelper(name, BindingFlags.SetField | bindingFlags, new object[] { value }, CultureInfo.InvariantCulture); - //} - - /// - /// Get the field or property - /// - /// Name of the field or property - /// The field or property. - public object GetFieldOrProperty(string name) - { - Helper.CheckParameterNotNull(name, "name", string.Empty); - return this.GetFieldOrProperty(name, BindToEveryThing); - } - - /// - /// Sets the field or property - /// - /// Name of the field or property - /// value to set - public void SetFieldOrProperty(string name, object value) - { - Helper.CheckParameterNotNull(name, "name", string.Empty); - this.SetFieldOrProperty(name, BindToEveryThing, value); - } - - /// - /// Gets the field or property - /// - /// Name of the field or property - /// A bitmask comprised of one or more that specify how the search is conducted. - /// The field or property. - public object GetFieldOrProperty(string name, BindingFlags bindingFlags) - { - Helper.CheckParameterNotNull(name, "name", string.Empty); - return this.InvokeHelper(name, BindingFlags.GetField | BindingFlags.GetProperty | bindingFlags, null, CultureInfo.InvariantCulture); - } - - /// - /// Sets the field or property - /// - /// Name of the field or property - /// A bitmask comprised of one or more that specify how the search is conducted. - /// value to set - public void SetFieldOrProperty(string name, BindingFlags bindingFlags, object value) - { - Helper.CheckParameterNotNull(name, "name", string.Empty); - this.InvokeHelper(name, BindingFlags.SetField | BindingFlags.SetProperty | bindingFlags, new object[] { value }, CultureInfo.InvariantCulture); - } - - ///// - ///// Gets the property - ///// - ///// Name of the property - ///// Arguments to pass to the member to invoke. - ///// The property. - //public object GetProperty(string name, params object[] args) - //{ - // return this.GetProperty(name, null, args); - //} - - ///// - ///// Gets the property - ///// - ///// Name of the property - ///// An array of objects representing the number, order, and type of the parameters for the indexed property. - ///// Arguments to pass to the member to invoke. - ///// The property. - //public object GetProperty(string name, Type[] parameterTypes, object[] args) - //{ - // return this.GetProperty(name, BindToEveryThing, parameterTypes, args); - //} - - ///// - ///// Set the property - ///// - ///// Name of the property - ///// value to set - ///// Arguments to pass to the member to invoke. - //public void SetProperty(string name, object value, params object[] args) - //{ - // this.SetProperty(name, null, value, args); - //} - - ///// - ///// Set the property - ///// - ///// Name of the property - ///// An array of objects representing the number, order, and type of the parameters for the indexed property. - ///// value to set - ///// Arguments to pass to the member to invoke. - //public void SetProperty(string name, Type[] parameterTypes, object value, object[] args) - //{ - // this.SetProperty(name, BindToEveryThing, value, parameterTypes, args); - //} - - ///// - ///// Gets the property - ///// - ///// Name of the property - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// Arguments to pass to the member to invoke. - ///// The property. - //public object GetProperty(string name, BindingFlags bindingFlags, params object[] args) - //{ - // return this.GetProperty(name, bindingFlags, null, args); - //} - - ///// - ///// Gets the property - ///// - ///// Name of the property - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// An array of objects representing the number, order, and type of the parameters for the indexed property. - ///// Arguments to pass to the member to invoke. - ///// The property. - //public object GetProperty(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) - //{ - // Helper.CheckParameterNotNull(name, "name", string.Empty); - // if (parameterTypes != null) - // { - // PropertyInfo pi = this.originalType.GetProperty(name, bindingFlags, null, null, parameterTypes, null); - // if (pi == null) - // { - // throw new ArgumentException( - // string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); - // } - - // return pi.GetValue(this.target, args); - // } - // else - // { - // return this.InvokeHelper(name, bindingFlags | BindingFlags.GetProperty, args, null); - // } - //} - - ///// - ///// Sets the property - ///// - ///// Name of the property - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// value to set - ///// Arguments to pass to the member to invoke. - //public void SetProperty(string name, BindingFlags bindingFlags, object value, params object[] args) - //{ - // this.SetProperty(name, bindingFlags, value, null, args); - //} - - ///// - ///// Sets the property - ///// - ///// Name of the property - ///// A bitmask comprised of one or more that specify how the search is conducted. - ///// value to set - ///// An array of objects representing the number, order, and type of the parameters for the indexed property. - ///// Arguments to pass to the member to invoke. - //public void SetProperty(string name, BindingFlags bindingFlags, object value, Type[] parameterTypes, object[] args) - //{ - // Helper.CheckParameterNotNull(name, "name", string.Empty); - - // if (parameterTypes != null) - // { - // PropertyInfo pi = this.originalType.GetProperty(name, bindingFlags, null, null, parameterTypes, null); - // if (pi == null) - // { - // throw new ArgumentException( - // string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); - // } - - // pi.SetValue(this.target, value, args); - // } - // else - // { - // object[] pass = new object[(args?.Length ?? 0) + 1]; - // pass[0] = value; - // args?.CopyTo(pass, 1); - // this.InvokeHelper(name, bindingFlags | BindingFlags.SetProperty, pass, null); - // } - //} - - #region Private Helpers - - ///// - ///// Validate access string - ///// - ///// access string - //private static void ValidateAccessString(string access) - //{ - // Helper.CheckParameterNotNull(access, "access", string.Empty); - // if (access.Length == 0) - // { - // throw new ArgumentException(FrameworkMessages.AccessStringInvalidSyntax); - // } - - // string[] arr = access.Split('.'); - // foreach (string str in arr) - // { - // if ((str.Length == 0) || (str.IndexOfAny(new char[] { ' ', '\t', '\n' }) != -1)) - // { - // throw new ArgumentException(FrameworkMessages.AccessStringInvalidSyntax); - // } - // } - //} - - /// - /// Invokes the memeber - /// - /// Name of the member - /// Additional attributes - /// Arguments for the invocation - /// Culture - /// Result of the invocation - private object InvokeHelper(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) - { - Helper.CheckParameterNotNull(name, "name", string.Empty); - //Debug.Assert(this.target != null, "Internal Error: Null reference is returned for internal object"); - - // Invoke the actual Method - try - { - return this.originalType.InvokeMember(name, bindingFlags, null, this.target, args, culture); - } - catch (TargetInvocationException e) - { - //Debug.Assert(e.InnerException != null, "Inner exception should not be null."); - if (e.InnerException != null) - { - throw e.InnerException; - } - - throw; - } - } - - private void ConstructFrom(object obj) - { - Helper.CheckParameterNotNull(obj, "obj", string.Empty); - this.target = obj; - this.originalType = obj.GetType(); - } - - //private void BuildGenericMethodCacheForType(Type t) - //{ - // Debug.Assert(t != null, "type should not be null."); - // this.methodCache = new Dictionary>(); - - // MethodInfo[] members = t.GetMethods(BindToEveryThing); - // LinkedList listByName; // automatically initialized to null - - // foreach (MethodInfo member in members) - // { - // if (member.IsGenericMethod || member.IsGenericMethodDefinition) - // { - // if (!this.GenericMethodCache.TryGetValue(member.Name, out listByName)) - // { - // listByName = new LinkedList(); - // this.GenericMethodCache.Add(member.Name, listByName); - // } - - // Debug.Assert(listByName != null, "list should not be null."); - // listByName.AddLast(member); - // } - // } - //} - - ///// - ///// Extracts the most appropriate generic method signature from the current private type. - ///// - ///// The name of the method in which to search the signature cache. - ///// An array of types corresponding to the types of the parameters in which to search. - ///// An array of types corresponding to the types of the generic arguments. - ///// to further filter the method signatures. - ///// Modifiers for parameters. - ///// A methodinfo instance. - //private MethodInfo GetGenericMethodFromCache(string methodName, Type[] parameterTypes, Type[] typeArguments, BindingFlags bindingFlags, ParameterModifier[] modifiers) - //{ - // Debug.Assert(!string.IsNullOrEmpty(methodName), "Invalid method name."); - // Debug.Assert(parameterTypes != null, "Invalid parameter type array."); - // Debug.Assert(typeArguments != null, "Invalid type arguments array."); - - // // Build a preliminary list of method candidates that contain roughly the same signature. - // var methodCandidates = this.GetMethodCandidates(methodName, parameterTypes, typeArguments, bindingFlags, modifiers); - - // // Search of ambiguous methods (methods with the same signature). - // MethodInfo[] finalCandidates = new MethodInfo[methodCandidates.Count]; - // methodCandidates.CopyTo(finalCandidates, 0); - - // if ((parameterTypes != null) && (parameterTypes.Length == 0)) - // { - // for (int i = 0; i < finalCandidates.Length; i++) - // { - // MethodInfo methodInfo = finalCandidates[i]; - - // if (!RuntimeTypeHelper.CompareMethodSigAndName(methodInfo, finalCandidates[0])) - // { - // throw new AmbiguousMatchException(); - // } - // } - - // // All the methods have the exact same name and sig so return the most derived one. - // return RuntimeTypeHelper.FindMostDerivedNewSlotMeth(finalCandidates, finalCandidates.Length) as MethodInfo; - // } - - // // Now that we have a preliminary list of candidates, select the most appropriate one. - // return RuntimeTypeHelper.SelectMethod(bindingFlags, finalCandidates, parameterTypes, modifiers) as MethodInfo; - //} - - //private LinkedList GetMethodCandidates(string methodName, Type[] parameterTypes, Type[] typeArguments, BindingFlags bindingFlags, ParameterModifier[] modifiers) - //{ - // Debug.Assert(!string.IsNullOrEmpty(methodName), "methodName should not be null."); - // Debug.Assert(parameterTypes != null, "parameterTypes should not be null."); - // Debug.Assert(typeArguments != null, "typeArguments should not be null."); - - // LinkedList methodCandidates = new LinkedList(); - // LinkedList methods = null; - - // if (!this.GenericMethodCache.TryGetValue(methodName, out methods)) - // { - // return methodCandidates; - // } - - // Debug.Assert(methods != null, "methods should not be null."); - - // foreach (MethodInfo candidate in methods) - // { - // bool paramMatch = true; - // ParameterInfo[] candidateParams = null; - // Type[] genericArgs = candidate.GetGenericArguments(); - // Type sourceParameterType = null; - - // if (genericArgs.Length != typeArguments.Length) - // { - // continue; - // } - - // // Since we can't just get the correct MethodInfo from Reflection, - // // we will just match the number of parameters, their order, and their type - // var methodCandidate = candidate; - // candidateParams = methodCandidate.GetParameters(); - - // if (candidateParams.Length != parameterTypes.Length) - // { - // continue; - // } - - // // Exact binding - // if ((bindingFlags & BindingFlags.ExactBinding) != 0) - // { - // int i = 0; - - // foreach (ParameterInfo candidateParam in candidateParams) - // { - // sourceParameterType = parameterTypes[i++]; - - // if (candidateParam.ParameterType.ContainsGenericParameters) - // { - // // Since we have a generic parameter here, just make sure the IsArray matches. - // if (candidateParam.ParameterType.IsArray != sourceParameterType.IsArray) - // { - // paramMatch = false; - // break; - // } - // } - // else - // { - // if (candidateParam.ParameterType != sourceParameterType) - // { - // paramMatch = false; - // break; - // } - // } - // } - - // if (paramMatch) - // { - // methodCandidates.AddLast(methodCandidate); - // continue; - // } - // } - // else - // { - // methodCandidates.AddLast(methodCandidate); - // } - // } - - // return methodCandidates; - //} - - #endregion - } -} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs b/test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs deleted file mode 100644 index fc196304..00000000 --- a/test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs +++ /dev/null @@ -1,314 +0,0 @@ -// -// Copyright (c) 2019 cactuaroid All Rights Reserved -// -// -// Released under the MIT license -// https://github.com/cactuaroid/PrivateObjectExtensions -// - -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System.Linq; -using System.Reflection; - -namespace System -{ - /// - /// Extension methods for PrivateObject - /// - public static class PrivateObjectExtensions - { - private static readonly BindingFlags Static = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly | BindingFlags.Static; - private static readonly BindingFlags Instance = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly | BindingFlags.Instance; - - /// - /// Get from private (and any other) field/property. - /// If the real type of specified object doesn't contain the specified field/property, - /// base types are searched automatically. - /// - /// The object to get from - /// The name of the field/property - /// The object got from the field/property - /// 'name' is not found. - /// Arguments contain null. - public static object GetPrivate(this object obj, string name) - { - if (obj == null) { throw new ArgumentNullException("obj"); } - - return GetPrivate(obj, name, obj.GetType(), null); - } - - /// - /// Get from private (and any other) field/property. - /// If the real type of specified object doesn't contain the specified field/property, - /// base types are searched automatically. - /// - /// The type of the field/property - /// The object to get from - /// The name of the field/property - /// The object got from the field/property - /// 'name' is not found. - /// Arguments contain null. - public static T GetPrivate(this object obj, string name) - { - if (obj == null) { throw new ArgumentNullException("obj"); } - - return (T)GetPrivate(obj, name, obj.GetType(), typeof(T)); - } - - /// - /// Get from private (and any other) field/property with assuming the specified object as specified type. - /// If the specified type doesn't contain the specified field/property, - /// base types are searched automatically. - /// - /// The object to get from - /// The name of the field/property - /// The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored. - /// The object got from the field/property - /// 'name' is not found. - /// 'objType' is not assignable from 'obj'. - /// Arguments contain null. - public static object GetPrivate(this object obj, string name, Type objType) - { - return GetPrivate(obj, name, objType, null); - } - - /// - /// Get from private (and any other) field/property with assuming the specified object as specified type. - /// If the specified type doesn't contain the specified field/property, - /// base types are searched automatically. - /// - /// The type of the field/property - /// The object to get from - /// The name of the field/property - /// The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored. - /// The object got from the field/property - /// 'name' is not found. - /// 'objType' is not assignable from 'obj'. - /// Arguments contain null. - public static T GetPrivate(this object obj, string name, Type objType) - { - return (T)GetPrivate(obj, name, objType, typeof(T)); - } - - private static object GetPrivate(object obj, string name, Type objType, Type memberType) - { - if (obj == null) { throw new ArgumentNullException("obj"); } - if (name == null) { throw new ArgumentNullException("name"); } - if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } - if (objType == null) { throw new ArgumentNullException("objType"); } - if (!objType.IsAssignableFrom(obj.GetType())) { throw new ArgumentException($"{objType} is not assignable from {obj.GetType()}.", "objType"); } - - bool memberTypeMatching(Type actualType) => actualType == memberType; - - if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Instance, out var ownerType)) - { - return new PrivateObject(obj, new PrivateType(ownerType)).GetFieldOrProperty(name); - } - else if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Static, out ownerType)) - { - return new PrivateType(ownerType).GetStaticFieldOrProperty(name); - } - - throw new ArgumentException(((memberType != null) ? memberType + " " : "") + name + " is not found."); - } - - /// - /// Get from private (and any other) static field/property. - /// - /// The type to get from - /// The name of the static field/property - /// The object got from the static field/property - /// 'name' is not found. - /// Arguments contain null. - public static object GetPrivate(this Type type, string name) - { - return GetPrivate(type, name, null); - } - - /// - /// Get from private (and any other) static field/property. - /// - /// The type of the field/property - /// The type to get from - /// The name of the static field/property - /// The object got from the static field/property - /// 'name' is not found. - /// Arguments contain null. - public static T GetPrivate(this Type type, string name) - { - return (T)GetPrivate(type, name, typeof(T)); - } - - private static object GetPrivate(this Type type, string name, Type memberType) - { - if (type == null) { throw new ArgumentNullException("type"); } - if (name == null) { throw new ArgumentNullException("name"); } - if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } - - bool memberTypeMatching(Type actualType) => actualType == memberType; - - if (type.ContainsFieldOrProperty(name, memberType, memberTypeMatching, Static)) - { - return new PrivateType(type).GetStaticFieldOrProperty(name); - } - - throw new ArgumentException(((memberType != null) ? memberType + " " : "") + name + " is not found."); - } - - /// - /// Set to private (and any other) field/property. - /// If the real type of specified object doesn't contain the specified field/property, - /// base types are searched automatically. - /// - /// The object to set to - /// The name of the field/property - /// The value to set for 'name' - /// 'name' is not found. - /// Arguments contain null. - public static void SetPrivate(this object obj, string name, T value) - { - if (obj == null) { throw new ArgumentNullException("obj"); } - - SetPrivate(obj, name, value, obj.GetType()); - } - - /// - /// Set to private (and any other) field/property with assuming the specified object as specified type. - /// If the specified type doesn't contain the specified field/property, - /// base types are searched automatically. - /// - /// The object to set to - /// The name of the field/property - /// The value to set for 'name' - /// The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored. - /// 'name' is not found. - /// 'objType' is not assignable from 'obj'. - /// Arguments contain null. - public static void SetPrivate(this object obj, string name, T value, Type objType) - { - if (obj == null) { throw new ArgumentNullException("obj"); } - if (name == null) { throw new ArgumentNullException("name"); } - if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } - if (value == null) { throw new ArgumentNullException("value"); } - if (objType == null) { throw new ArgumentNullException("objType"); } - if (!objType.IsAssignableFrom(obj.GetType())) { throw new ArgumentException($"{objType} is not assignable from {obj.GetType()}.", "objType"); } - - if (TrySetPrivate(obj, name, value, objType)) { return; } - - // retry for the case of getter only property - if (TrySetPrivate(obj, GetBackingFieldName(name), value, objType)) { return; } - - throw new ArgumentException($"{typeof(T)} {name} is not found."); - } - - private static bool TrySetPrivate(object obj, string name, T value, Type objType) - { - var memberType = typeof(T); - bool memberTypeMatching(Type actualType) => actualType.IsAssignableFrom(memberType); - - try - { - if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Instance, out var ownerType)) - { - new PrivateObject(obj, new PrivateType(ownerType)).SetFieldOrProperty(name, value); - return true; - } - else if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Static, out ownerType)) - { - new PrivateType(ownerType).SetStaticFieldOrProperty(name, value); - return true; - } - } - catch (MissingMethodException) - { - // When getter only property name is given, the property is found but fails to set. - return false; - } - - return false; - } - - /// - /// Set to private (and any other) static field/property. - /// - /// The type to set to - /// The name of the field/property - /// The value to set for 'name' - /// 'name' is not found. - /// Arguments contain null. - public static void SetPrivate(this Type type, string name, T value) - { - if (type == null) { throw new ArgumentNullException("type"); } - if (name == null) { throw new ArgumentNullException("name"); } - if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } - - if (TrySetPrivate(type, name, value)) { return; } - - // retry for the case of getter only property - if (TrySetPrivate(type, GetBackingFieldName(name), value)) { return; } - - throw new ArgumentException($"{typeof(T)} {name} is not found."); - } - - private static bool TrySetPrivate(this Type type, string name, T value) - { - var memberType = typeof(T); - bool memberTypeMatching(Type actualType) => actualType.IsAssignableFrom(memberType); - - try - { - if (type.ContainsFieldOrProperty(name, memberType, memberTypeMatching, Static)) - { - new PrivateType(type).SetStaticFieldOrProperty(name, value); - return true; - } - } - catch (MissingMethodException) - { - // When getter only property name is given, the property is found but fails to set. - return false; - } - - return false; - } - - private static string GetBackingFieldName(string propertyName) - => $"<{propertyName}>k__BackingField"; // generated backing field name - - private static bool TryFindFieldOrPropertyOwnerType(Type objType, string name, Type memberType, Func memberTypeMatching, BindingFlags bindingFlag, out Type ownerType) - { - ownerType = FindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, bindingFlag); - - return (ownerType != null); - } - - private static Type FindFieldOrPropertyOwnerType(Type objectType, string name, Type memberType, Func memberTypeMatching, BindingFlags bindingFlags) - { - if (objectType == null) { return null; } - - if (objectType.ContainsFieldOrProperty(name, memberType, memberTypeMatching, bindingFlags)) - { - return objectType; - } - - return FindFieldOrPropertyOwnerType(objectType.BaseType, name, memberType, memberTypeMatching, bindingFlags); - } - - private static bool ContainsFieldOrProperty(this Type objectType, string name, Type memberType, Func memberTypeMatching, BindingFlags bindingFlags) - { - var fields = objectType - .GetFields(bindingFlags) - .Select((x) => new { Type = x.FieldType, Member = x as MemberInfo }); - - var properties = objectType - .GetProperties(bindingFlags) - .Select((x) => new { Type = x.PropertyType, Member = x as MemberInfo }); - - var members = fields.Concat(properties); - - return members.Any((actual) => - (memberType == null || memberTypeMatching.Invoke(actual.Type)) - && actual.Member.Name == name); - } - } -} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs deleted file mode 100644 index b0597bdf..00000000 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs +++ /dev/null @@ -1,172 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; - -namespace TensorFlowNET.UnitTest.control_flow_ops_test -{ - /// - /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py - /// - [TestClass] - public class SwitchTestCase : PythonTest - { - - [Ignore("TODO")] - [TestMethod] - public void testResourceReadInLoop() - { - - //var embedding_matrix = variable_scope.get_variable( - //"embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true); - - /* - Tensor cond(Tensor it, Tensor _) - { - return it < 5; - } - */ - - // TODO: below code doesn't compile - //(Tensor, Tensor) body(Tensor it, Tensor cost) - //{ - // var embedding = embedding_ops.embedding_lookup(embedding_matrix, new int[]{0}); - // cost += math_ops.reduce_sum(embedding); - // return (it + 1, cost); - //} - //var (_, cost1) = control_flow_ops.while_loop( - // cond, body, new[] - // { - // constant_op.constant(0), - // constant_op.constant(0.0) - // }); - //with(this.cached_session(), sess => - //{ - // self.evaluate(variables.global_variables_initializer()); - // self.assertAllEqual(10.0, self.evaluate(cost1)); - //}); - } - - - [Ignore("TODO")] - [TestMethod] - public void testIndexedSlicesGradientInCondInWhileLoop() - { - doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: false); - } - - [Ignore("TODO")] - [TestMethod] - public void testIndexedSlicesGradientInCondInWhileLoopResource() - { - doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: true); - } - - private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource = false) - { - //def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False): - // embedding_matrix = variable_scope.get_variable( - // "embedding_matrix", [5, 5], - // initializer=init_ops.random_normal_initializer(), - // use_resource=use_resource) - - // def cond(it, _): - // return it < 5 - - // def body(it, cost): - // embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) - // cost = control_flow_ops.cond( - // math_ops.equal(it, 3), lambda: math_ops.square(cost), - // (lambda: cost + math_ops.reduce_sum(embedding))) - // return it + 1, cost - - // _, cost = control_flow_ops.while_loop( - // cond, body, [constant_op.constant(0), - // constant_op.constant(0.0)]) - - // dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0] - // dynamic_grads = math_ops.segment_sum(dynamic_grads.values, - // dynamic_grads.indices) - - // embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) - // static = math_ops.square( - // math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) + - // math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding) - // static_grads = gradients_impl.gradients(static, [embedding_matrix])[0] - // static_grads = math_ops.segment_sum(static_grads.values, - // static_grads.indices) - - // with self.cached_session(): - // self.evaluate(variables.global_variables_initializer()) - // self.assertAllEqual(*self.evaluate([static_grads, dynamic_grads])) - } - - [Ignore("TODO")] - [TestMethod] - public void testIndexedSlicesWithShapeGradientInWhileLoop() - { - //@test_util.run_v1_only("b/120545219") - //def testIndexedSlicesWithShapeGradientInWhileLoop(self): - // for dtype in [dtypes.float32, dtypes.float64]: - // with self.cached_session() as sess: - // num_steps = 9 - - // inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps]) - // initial_outputs = tensor_array_ops.TensorArray( - // dtype=dtype, size=num_steps) - // initial_i = constant_op.constant(0, dtype=dtypes.int32) - - // def cond(i, _): - // return i < num_steps # pylint: disable=cell-var-from-loop - - // def body(i, outputs): - // x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop - // outputs = outputs.write(i, x) - // return i + 1, outputs - - // _, outputs = control_flow_ops.while_loop(cond, body, - // [initial_i, initial_outputs]) - - // outputs = math_ops.reduce_sum(outputs.stack()) - // r = gradients_impl.gradients([outputs], [inputs])[0] - // grad_wr_inputs = ops.convert_to_tensor(r) - // o, grad = sess.run([outputs, grad_wr_inputs], - // feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]}) - // self.assertEquals(o, 20) - // self.assertAllEqual(grad, [1] * num_steps) - - } - - [Ignore("TODO")] - [TestMethod] - public void testIndexedSlicesWithDynamicShapeGradientInWhileLoop() - { - //@test_util.run_v1_only("b/120545219") - //def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self): - // for dtype in [dtypes.float32, dtypes.float64]: - // with self.cached_session() as sess: - // inputs = array_ops.placeholder(dtype=dtype) - // initial_outputs = tensor_array_ops.TensorArray( - // dtype=dtype, dynamic_size=True, size=1) - // initial_i = constant_op.constant(0, dtype=dtypes.int32) - - // def cond(i, _): - // return i < array_ops.size(inputs) # pylint: disable=cell-var-from-loop - - // def body(i, outputs): - // x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop - // outputs = outputs.write(i, x) - // return i + 1, outputs - - // _, outputs = control_flow_ops.while_loop(cond, body, - // [initial_i, initial_outputs]) - - // outputs = math_ops.reduce_sum(outputs.stack()) - // r = gradients_impl.gradients([outputs], [inputs])[0] - // grad_wr_inputs = ops.convert_to_tensor(r) - // o, grad = sess.run([outputs, grad_wr_inputs], - // feed_dict={inputs: [1, 3, 2]}) - // self.assertEquals(o, 6) - // self.assertAllEqual(grad, [1] * 3) - - } - - } -}