From 6fe6057ff4a9c66497d79ebb887cb329660396ce Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 9 Jan 2021 08:06:44 -0600 Subject: [PATCH] Adjust unit test project. --- TensorFlow.NET.sln | 38 +- .../Tensorflow.Console.csproj | 1 + src/TensorFlowNET.Core/APIs/tf.graph.cs | 11 +- .../Contexts/Context.AutoMode.cs | 14 +- src/TensorFlowNET.Core/Contexts/Context.cs | 12 +- .../Contexts/ContextSwitchStack.cs | 10 + src/TensorFlowNET.Core/Data/MapDataset.cs | 6 +- .../Functions/ConcreteFunction.cs | 21 +- .../Functions/EagerDefinedFunction.cs | 1 - .../Functions/TapeGradientFunctions.cs | 3 +- src/TensorFlowNET.Core/Graphs/AutoGraph.cs | 6 +- .../Graphs/AutoGraphAttribute.cs | 5 +- .../Graphs/DefaultGraphStack.cs | 56 +- src/TensorFlowNET.Core/Graphs/FuncGraph.cs | 7 +- src/TensorFlowNET.Core/Graphs/Graph.cs | 8 +- .../Graphs/SubGraphUtility.cs | 2 +- .../Operations/OpDefLibrary.cs | 7 +- src/TensorFlowNET.Core/ops.cs | 20 + src/TensorFlowNET.Core/ops.threading.cs | 18 +- src/TensorFlowNET.Keras/BackendImpl.cs | 6 +- .../Layer.FunctionalConstructionCall.cs | 2 +- .../Layers/Core/InputLayer.cs | 3 +- .../Tensorflow.Recommenders.csproj | 1 + src/TensorFlowNET.Text/Tensorflow.Text.csproj | 5 + .../CApiAttributesTestcs.cs | 3 +- .../CApiColocationTest.cs | 3 +- .../CApiFunctionTest.cs | 5 +- .../CApiGradientsTest.cs | 4 +- .../CApiTest.cs | 6 +- .../CSession.cs | 3 +- .../Eager/CApi.Eager.Context.cs | 3 +- .../Eager/CApi.Eager.Execute_MatMul_CPU.cs | 3 +- .../CApi.Eager.OpGetInputAndOutputLengths.cs | 3 +- ...pi.Eager.OpInferMixedTypeInputListAttrs.cs | 3 +- .../Eager/CApi.Eager.TensorHandle.cs | 3 +- .../Eager/CApi.Eager.TensorHandleDevices.cs | 3 +- .../Eager/CApi.Eager.Variables.cs | 3 +- .../Eager/CApi.Eager.cs | 3 +- .../Eager/GradientEagerTest.cs | 5 +- .../GraphBuildTest.cs | 3 +- .../GraphTest.cs | 4 +- .../Sessions/SessionTest.cs | 74 + .../Tensorflow.Native.UnitTest.csproj | 36 + .../Tensors/TensorTest.cs | 204 ++ .../c_test_util.cs | 4 +- .../Basics/SessionTest.cs | 73 +- .../Basics/TensorTest.cs | 223 +- .../EagerModeTestBase.cs | 30 + .../EnforcedSinglethreadingTests.cs | 4 +- .../GraphModeTestBase.cs | 2 + .../ManagedAPI/ActivationFunctionTest.cs | 2 +- .../ManagedAPI/BitwiseApiTest.cs | 2 +- .../ManagedAPI/FunctionApiTest.cs | 2 +- .../ManagedAPI/MathApiTest.cs | 2 +- .../ManagedAPI/TFNetApiTest.cs | 35 - .../ManagedAPI/ZeroFractionTest.cs | 86 - .../MultithreadingTests.cs | 1 + .../Utilities/TestHelper.cs | 15 +- .../control_flow_ops_test.py | 1059 ------ .../nest_test/NestTest.cs | 875 ----- .../nest_test/nest_test.py | 883 ----- .../ops_test/ControlDependenciesTest.cs | 249 -- .../ops_test/CreateOpFromTfOperationTest.cs | 222 -- .../ops_test/GraphTest.cs | 196 -- .../ops_test/ops_test_r1.13.py | 3014 ----------------- .../python/train_saver.py | 26 - 66 files changed, 553 insertions(+), 7089 deletions(-) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/CApiAttributesTestcs.cs (98%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/CApiColocationTest.cs (98%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/CApiFunctionTest.cs (99%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/CApiGradientsTest.cs (99%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/CApiTest.cs (98%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/CSession.cs (98%) rename test/{TensorFlowNET.UnitTest/NativeAPI => 
TensorFlowNET.Native.UnitTest}/Eager/CApi.Eager.Context.cs (95%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/Eager/CApi.Eager.Execute_MatMul_CPU.cs (97%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs (97%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs (97%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/Eager/CApi.Eager.TensorHandle.cs (93%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/Eager/CApi.Eager.TensorHandleDevices.cs (98%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/Eager/CApi.Eager.Variables.cs (97%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/Eager/CApi.Eager.cs (99%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/Eager/GradientEagerTest.cs (96%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/GraphBuildTest.cs (94%) rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/GraphTest.cs (99%) create mode 100644 test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs create mode 100644 test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj create mode 100644 test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs rename test/{TensorFlowNET.UnitTest/NativeAPI => TensorFlowNET.Native.UnitTest}/c_test_util.cs (98%) delete mode 100644 test/TensorFlowNET.UnitTest/ManagedAPI/TFNetApiTest.cs delete mode 100644 test/TensorFlowNET.UnitTest/ManagedAPI/ZeroFractionTest.cs delete mode 100644 test/TensorFlowNET.UnitTest/control_flow_ops_test/control_flow_ops_test.py delete mode 100644 test/TensorFlowNET.UnitTest/nest_test/NestTest.cs delete mode 100644 test/TensorFlowNET.UnitTest/nest_test/nest_test.py delete mode 100644 test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs delete mode 100644 test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs delete mode 100644 test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs delete mode 100644 test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py delete mode 100644 test/TensorFlowNET.UnitTest/python/train_saver.py diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 3173c1a9..ffc39b5a 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -17,6 +17,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Text", "src\Tens EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Recommenders", "src\TensorFlowNET.Recommenders\Tensorflow.Recommenders.csproj", "{F17AAECB-960A-4E18-A270-BAD776F0E55B}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Native.UnitTest", "test\TensorFlowNET.Native.UnitTest\Tensorflow.Native.UnitTest.csproj", "{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -107,8 +109,8 @@ Global {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.Build.0 = Release|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.Build.0 = Debug|Any CPU - {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|Any CPU - {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.Build.0 = Debug|Any CPU + 
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|x64 + {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.Build.0 = Debug|x64 {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.ActiveCfg = Debug|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.Build.0 = Debug|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU @@ -155,8 +157,8 @@ Global {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Release|x86.Build.0 = Release|Any CPU {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.Build.0 = Debug|Any CPU - {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.ActiveCfg = Debug|Any CPU - {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.Build.0 = Debug|Any CPU + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.ActiveCfg = Debug|x64 + {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.Build.0 = Debug|x64 {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.ActiveCfg = Debug|Any CPU {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.Build.0 = Debug|Any CPU {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU @@ -179,8 +181,8 @@ Global {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x86.Build.0 = Release|Any CPU {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.Build.0 = Debug|Any CPU - {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.ActiveCfg = Debug|Any CPU - {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.Build.0 = Debug|Any CPU + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.ActiveCfg = Debug|x64 + {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.Build.0 = Debug|x64 {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.ActiveCfg = Debug|Any CPU {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.Build.0 = Debug|Any CPU {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU @@ -201,6 +203,30 @@ Global {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x64.Build.0 = Release|Any CPU {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.ActiveCfg = Release|Any CPU {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.Build.0 = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x64.ActiveCfg = Debug|x64 + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x64.Build.0 = Debug|x64 + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.ActiveCfg = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.Build.0 = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x64.Build.0 = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x86.ActiveCfg = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x86.Build.0 = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|Any CPU.Build.0 = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x64.ActiveCfg = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x64.Build.0 = Debug|Any CPU + 
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x86.ActiveCfg = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x86.Build.0 = Debug|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.Build.0 = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.ActiveCfg = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.Build.0 = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.ActiveCfg = Release|Any CPU + {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowNET.Console/Tensorflow.Console.csproj b/src/TensorFlowNET.Console/Tensorflow.Console.csproj index d2b64e80..11dda95f 100644 --- a/src/TensorFlowNET.Console/Tensorflow.Console.csproj +++ b/src/TensorFlowNET.Console/Tensorflow.Console.csproj @@ -5,6 +5,7 @@ netcoreapp3.1 Tensorflow Tensorflow + AnyCPU;x64 diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs index 7c0e7585..eec7f7f8 100644 --- a/src/TensorFlowNET.Core/APIs/tf.graph.cs +++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs @@ -28,17 +28,10 @@ namespace Tensorflow => ops.reset_default_graph(); public Graph get_default_graph() - { - return ops.get_default_graph(); - } + => ops.get_default_graph(); - /// - /// Equivalent to but does not create a new graph if it there is none. - /// public Graph peak_default_graph() - { - return ops.default_graph_stack.peak_controller(); - } + => ops.peak_default_graph(); /// /// Creates a new graph. diff --git a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs index 3626e9df..1a5c00d2 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs @@ -37,19 +37,7 @@ namespace Tensorflow.Contexts if (shouldRunInEager) return eagerAction(); else - { - if (executing_eagerly()) - { - graph_mode(); - var result = graphAction(); - restore_mode(); - return result; - } - else - { - return graphAction(); - } - } + return graphAction(); } // [DebuggerStepThrough] diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 17e85306..43564fdb 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -80,9 +80,12 @@ namespace Tensorflow.Contexts /// Checks whether the current thread has eager execution enabled. 
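+ /// (With this change, querying before any ContextSwitch has been pushed will
+ /// call tf.enable_eager_execution() first, so the very first query is safe.)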
/// /// - [DebuggerStepThrough] + // [DebuggerStepThrough] public bool executing_eagerly() { + if(context_switches.Count() == 0) + tf.enable_eager_execution(); + return context_switches.Current().EagerMode; } @@ -103,11 +106,16 @@ namespace Tensorflow.Contexts public void restore_mode() { context_switches.Pop(); + tf.get_default_graph(); } public void reset_context() { - c_api.TFE_ContextClearCaches(_handle); + ops.reset_uid(); + ops.reset_default_graph(); + context_switches.Clear(); + if (_handle != null) + c_api.TFE_ContextClearCaches(_handle); } public void Dispose() diff --git a/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs b/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs index 84bc3889..27704b3e 100644 --- a/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs +++ b/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs @@ -40,11 +40,21 @@ namespace Tensorflow.Contexts }); } + public void Clear() + { + stack.Clear(); + } + public void Pop() { stack.Pop(); } + public int Count() + { + return stack.Count; + } + public ContextSwitch Current() { return stack.Peek(); diff --git a/src/TensorFlowNET.Core/Data/MapDataset.cs b/src/TensorFlowNET.Core/Data/MapDataset.cs index 09513f94..50176cdc 100644 --- a/src/TensorFlowNET.Core/Data/MapDataset.cs +++ b/src/TensorFlowNET.Core/Data/MapDataset.cs @@ -15,11 +15,13 @@ namespace Tensorflow bool preserve_cardinality = false, bool use_legacy_function = false) : base(input_dataset) { - using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); + using var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}"); + func.Enter(); var input = tf.placeholder(input_dataset.element_spec[0].dtype); var output = map_func(input); func.ToGraph(input, output); - + func.Exit(); + structure = func.OutputStructure; variant_tensor = ops.map_dataset(input_dataset.variant_tensor, diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 45fa3420..b1d932b6 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -34,7 +34,6 @@ namespace Tensorflow.Functions public ConcreteFunction(string name) { func_graph = new FuncGraph(name); - func_graph.as_default(); } public ConcreteFunction(FuncGraph graph, Dictionary attrs = null) @@ -46,7 +45,7 @@ namespace Tensorflow.Functions public ConcreteFunction(Func func, TF_DataType dtype) { - string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; + string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; // IntPtr func_handle; using var graph = new FuncGraph(func_name); @@ -59,11 +58,12 @@ namespace Tensorflow.Functions new[] { input }, new[] { output }, null); + graph.Exit(); } public ConcreteFunction(Func func, TF_DataType dtype) { - string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; + string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; // IntPtr func_handle; using var graph = new FuncGraph(func_name); @@ -79,12 +79,13 @@ namespace Tensorflow.Functions new[] { input }, new[] { output.variant_tensor }, null); + graph.Exit(); } public ConcreteFunction(Func func, TF_DataType[] dtypes, TensorShape[] shapes) { - string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; + string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; // IntPtr func_handle; using var graph = new FuncGraph(func_name); @@ -103,6 +104,7 @@ namespace Tensorflow.Functions new[] { input1, input2, input3 }, new[] { outputs.Item1, 
outputs.Item2 }, null); + graph.Exit(); } public void ToGraph(Tensors inputs, Tensors outputs) @@ -112,10 +114,19 @@ namespace Tensorflow.Functions inputs, outputs, null); - OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray(); } + public void Enter() + { + func_graph.as_default(); + } + + public void Exit() + { + func_graph.Exit(); + } + public Tensors Invoke(Tensors inputs) { var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly()); diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs index e0896253..f615f6a4 100644 --- a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs +++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs @@ -26,7 +26,6 @@ namespace Tensorflow.Functions var output_names = new string[0]; _func_graph = new FuncGraph(graph, name, attrs); - _func_graph.as_default(); _func_graph.ToGraph(operations, inputs, outputs, output_names); } diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index 5dd1a8ae..13c57e86 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -84,7 +84,7 @@ namespace Tensorflow.Functions } var gradients_wrt_outputs = new List(); - var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}"); + var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"); backwards_graph.as_default(); foreach (var output in trainable_outputs) gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); @@ -101,6 +101,7 @@ namespace Tensorflow.Functions if (!_func_graph.Outputs.Contains(capture)) _func_graph.Outputs.Add(capture); } + backwards_graph.Exit(); var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; var backward_function_attr = new Dictionary(); diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs index 901cbd6f..78b19a6f 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs @@ -8,7 +8,7 @@ namespace Tensorflow.Graphs { public Func to_graph(Func func) { - string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; + string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; // IntPtr func_handle; using (var graph = new FuncGraph(func_name)) @@ -22,6 +22,7 @@ namespace Tensorflow.Graphs new[] { input }, new[] { output }, null); + graph.Exit(); } @@ -39,7 +40,7 @@ namespace Tensorflow.Graphs public Func to_graph(Func func) { - string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; + string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; // IntPtr func_handle; using (var graph = new FuncGraph(func_name)) @@ -54,6 +55,7 @@ namespace Tensorflow.Graphs new[] { input1, input2 }, new[] { output }, null); + graph.Exit(); } return (Tensor a, Tensor b) => diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs index 3d04b8a3..010eb345 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs @@ -18,7 +18,7 @@ namespace Tensorflow.Graphs public override void OnEntry(MethodExecutionArgs args) { - func_name = $"autograph_{args.Instance.GetHashCode()}.{args.Method.Name}"; + func_name = $"{args.Method.Name}_{Guid.NewGuid()}"; if (functions.ContainsKey(func_name)) { 
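// A minimal sketch of the Enter()/ToGraph()/Exit() lifecycle that ConcreteFunction
// now expects (mirroring the MapDataset change above); the name "square_fn" and the
// traced ops are illustrative only:
//
//   using var fn = new ConcreteFunction("square_fn");
//   fn.Enter();                              // make fn's FuncGraph the default graph
//   var x = tf.placeholder(tf.float32);
//   var y = tf.square(x);                    // ops are traced into the FuncGraph
//   fn.ToGraph(x, y);
//   fn.Exit();                               // restore the previous graph / eager mode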
@@ -34,6 +34,7 @@ namespace Tensorflow.Graphs // make function as an Operation by autograph // need to restore mode when exits function = new ConcreteFunction(func_name); + function.Enter(); // convert to Tensors if (args.Arguments[0] is Tensors inputs) @@ -68,6 +69,8 @@ namespace Tensorflow.Graphs } else function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); + + function.Exit(); // cache function. function.ReturnType = args.ReturnValue.GetType(); diff --git a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs index 5407dcfe..0eee3cdb 100644 --- a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs +++ b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs @@ -25,63 +25,43 @@ namespace Tensorflow /// public class DefaultGraphStack { - private readonly List _stack = new List(); + private readonly Stack _stack = new Stack(); + Graph _global_default_graph; - public void set_controller(Graph @default) + public Graph get_default() { - if (!_stack.Exists(x => x.Graph == @default)) - _stack.Add(new StackModel { Graph = @default, IsDefault = true }); + if (_stack.Count > 0) + return _stack.Peek(); + else if (_global_default_graph != null) + return _global_default_graph; + else + _global_default_graph = new Graph(); - foreach (var s in _stack) - s.IsDefault = s.Graph == @default; + return _global_default_graph; } - public Graph get_controller() + public Graph get_controller(Graph g) { - if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0) - _stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); - for (var i = _stack.Count - 1; i >= 0; i--) - { - var x = _stack[i]; - if (x.IsDefault) - return x.Graph; - } - - throw new TensorflowException("Unable to find a default graph"); + _stack.Push(g); + return g; } public Graph peak_controller() { - if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0) + if (_stack.Count == 0) return null; - for (var i = _stack.Count - 1; i >= 0; i--) - { - var x = _stack[i]; - if (x.IsDefault) - return x.Graph; - } - - return null; + return _stack.Peek(); } - public bool remove(Graph g) + public void pop() { - if (_stack.Count == 0) - return false; - - var sm = _stack.Find(model => model.Graph == g); - return sm != null && _stack.Remove(sm); + _stack.Pop(); } public void reset() { _stack.Clear(); - } - - private class StackModel - { - public Graph Graph { get; set; } - public bool IsDefault { get; set; } + _global_default_graph = null; } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index 37f03267..797f81ad 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -94,8 +94,6 @@ namespace Tensorflow.Graphs // mark_as_return Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray(); - tf.Context.restore_mode(); - return func_handle; } @@ -247,9 +245,10 @@ namespace Tensorflow.Graphs return this; } - protected override void DisposeManagedResources() + public override void Exit() { - base.DisposeManagedResources(); + tf.Context.restore_mode(); + ops.pop_graph(); } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 81cb08c4..e3f14336 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -146,6 +146,7 @@ namespace Tensorflow /// /// Returns a context manager that makes this `Graph` the default graph. 
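+ /// A minimal usage sketch (illustrative only):
+ ///   var g = new Graph().as_default();   // push g onto the default graph stack
+ ///   var c = tf.constant(1);             // ops created here land in g
+ ///   g.Exit();                           // pop g when done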
+ /// Must call Exit() to pop graph /// /// public virtual Graph as_default() @@ -487,7 +488,7 @@ namespace Tensorflow protected override void DisposeManagedResources() { - ops.default_graph_stack.remove(this); + } protected override void DisposeUnmanagedResources(IntPtr handle) @@ -529,6 +530,11 @@ namespace Tensorflow return new TensorShape(dims.Select(x => (int)x).ToArray()); } + public virtual void Exit() + { + ops.pop_graph(); + } + string debugString = string.Empty; public override string ToString() { diff --git a/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs b/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs index 383fe8af..7c186f94 100644 --- a/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs +++ b/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs @@ -95,7 +95,7 @@ namespace Tensorflow.Graphs _copy_non_source(op, graph, op_map, base_graph); } - tf.Context.restore_mode(); + graph.Exit(); return op_map; } diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index c02de5e8..1b2dcd24 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -30,7 +30,7 @@ namespace Tensorflow public Operation _apply_op_helper(string op_type_name, string name = null, Dictionary keywords = null) { - var g = ops.get_default_graph(); + var g = ops._get_graph_from_inputs(keywords == null ? new object[0] : keywords.Values.ToArray()); var op_def = g.GetOpDef(op_type_name); // Default name if not specified. @@ -59,7 +59,8 @@ namespace Tensorflow var input_types = new List(); object values = null; - return tf_with(ops.name_scope(name), scope => + g.as_default(); + var ret_op = tf_with(ops.name_scope(name), scope => { var inferred_from = new Dictionary(); var base_types = new List(); @@ -249,6 +250,8 @@ namespace Tensorflow return op; }); + g.Exit(); + return ret_op; } private void _MaybeColocateWith(ITensorOrOperation[] inputs) diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 725a50b6..18dc4426 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -78,6 +78,21 @@ namespace Tensorflow return get_default_graph().get_collection_ref(key); } + public static Graph _get_graph_from_inputs(params object[] op_input_list) + { + var current_default_graph = get_default_graph(); + if (current_default_graph.building_function) + return current_default_graph; + + Graph graph = null; + foreach (var op_input in op_input_list) + { + if (op_input is Tensor op_input_tensor) + graph = graph ?? op_input_tensor.graph; + } + return graph ?? 
current_default_graph; + } + public static Graph _get_graph_from_inputs(Tensors op_input_list) => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); @@ -337,6 +352,11 @@ namespace Tensorflow return Interlocked.Increment(ref uid_number); } + public static void reset_uid() + { + uid_number = -1; + } + public static void colocate_with(bool ignore_existing = false) { _colocate_with_for_gradient(null, null, ignore_existing); diff --git a/src/TensorFlowNET.Core/ops.threading.cs b/src/TensorFlowNET.Core/ops.threading.cs index e436cae0..f52dbcae 100644 --- a/src/TensorFlowNET.Core/ops.threading.cs +++ b/src/TensorFlowNET.Core/ops.threading.cs @@ -118,16 +118,10 @@ namespace Tensorflow /// /// public static Graph get_default_graph() - { - //return _default_graph_stack.get_default() - return default_graph_stack.get_controller(); - } + => default_graph_stack.get_default(); - public static Graph set_default_graph(Graph graph) - { - default_graph_stack.set_controller(graph); - return default_graph_stack.get_controller(); - } + public static Graph set_default_graph(Graph g) + => default_graph_stack.get_controller(g); /// /// Clears the default graph stack and resets the global default graph. @@ -147,5 +141,11 @@ namespace Tensorflow // "exit the nesting and create a new graph."); default_graph_stack.reset(); } + + public static Graph peak_default_graph() + => default_graph_stack.peak_controller(); + + public static void pop_graph() + => default_graph_stack.pop(); } } \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index e24294fc..4c25f70a 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -115,10 +115,10 @@ namespace Tensorflow.Keras public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>(); public void clear_session() { - ops.reset_default_graph(); + tf.Context.reset_context(); reset_uids(); ops.set_default_session(tf.Session(ops.get_default_graph())); - var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase"); + // var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase"); _GRAPH_LEARNING_PHASES = new Dictionary(); _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0; } @@ -185,7 +185,7 @@ namespace Tensorflow.Keras return tensor_util.constant_value(outputs); var source_graph = outputs.graph; - using var exec_graph = _scratch_graph(); + var exec_graph = _scratch_graph(); var global_graph = get_graph(); if (source_graph == global_graph && exec_graph != global_graph) { diff --git a/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs b/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs index 25a7cd7a..699bec4c 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs @@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine _set_mask_metadata(inputs, outputs, null); }); - tf.Context.restore_mode(); + graph.Exit(); return outputs; } diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs index 54af955a..4e1e0e36 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs @@ -81,8 +81,9 @@ namespace Tensorflow.Keras.Layers sparse: args.Sparse, ragged: args.Ragged); + graph.Exit(); + isPlaceholder = true; - tf.Context.restore_mode(); } // Create an input 
node to add to self.outbound_node diff --git a/src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj b/src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj index ba47de36..8169ab66 100644 --- a/src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj +++ b/src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj @@ -5,6 +5,7 @@ 0.0.1 TensorFlow Recommenders is a library for building recommender system models using TensorFlow. LICENSE + AnyCPU;x64 diff --git a/src/TensorFlowNET.Text/Tensorflow.Text.csproj b/src/TensorFlowNET.Text/Tensorflow.Text.csproj index e687524c..1ce9c875 100644 --- a/src/TensorFlowNET.Text/Tensorflow.Text.csproj +++ b/src/TensorFlowNET.Text/Tensorflow.Text.csproj @@ -7,12 +7,17 @@ true 0.0.1 LICENSE + AnyCPU;x64 DEBUG;TRACE + + DEBUG;TRACE + + True diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiAttributesTestcs.cs b/test/TensorFlowNET.Native.UnitTest/CApiAttributesTestcs.cs similarity index 98% rename from test/TensorFlowNET.UnitTest/NativeAPI/CApiAttributesTestcs.cs rename to test/TensorFlowNET.Native.UnitTest/CApiAttributesTestcs.cs index 0a18f63f..a22d02e0 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiAttributesTestcs.cs +++ b/test/TensorFlowNET.Native.UnitTest/CApiAttributesTestcs.cs @@ -1,8 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; -using Tensorflow; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest { /// /// tensorflow\c\c_api_test.cc diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiColocationTest.cs b/test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs similarity index 98% rename from test/TensorFlowNET.UnitTest/NativeAPI/CApiColocationTest.cs rename to test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs index 9d892199..09b88e13 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiColocationTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs @@ -1,9 +1,8 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using System.Runtime.InteropServices; -using Tensorflow; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest { /// /// tensorflow\c\c_api_test.cc diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiFunctionTest.cs b/test/TensorFlowNET.Native.UnitTest/CApiFunctionTest.cs similarity index 99% rename from test/TensorFlowNET.UnitTest/NativeAPI/CApiFunctionTest.cs rename to test/TensorFlowNET.Native.UnitTest/CApiFunctionTest.cs index 42d2a22c..570c6ae1 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiFunctionTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/CApiFunctionTest.cs @@ -2,10 +2,9 @@ using System; using System.Collections.Generic; using System.Linq; -using Tensorflow; -using static TensorFlowNET.UnitTest.c_test_util; +using static Tensorflow.Native.UnitTest.c_test_util; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest { /// /// tensorflow\c\c_api_function_test.cc diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiGradientsTest.cs b/test/TensorFlowNET.Native.UnitTest/CApiGradientsTest.cs similarity index 99% rename from test/TensorFlowNET.UnitTest/NativeAPI/CApiGradientsTest.cs rename to test/TensorFlowNET.Native.UnitTest/CApiGradientsTest.cs index 77698f99..6db90aef 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiGradientsTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/CApiGradientsTest.cs @@ -1,11 +1,9 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using NumSharp; using System; -using Tensorflow; using 
Tensorflow.Util; -using Buffer = Tensorflow.Buffer; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest { /// /// tensorflow\c\c_api_test.cc diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs b/test/TensorFlowNET.Native.UnitTest/CApiTest.cs similarity index 98% rename from test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs rename to test/TensorFlowNET.Native.UnitTest/CApiTest.cs index c9366b5f..e8a9486f 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/CApiTest.cs @@ -1,13 +1,11 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; -using Tensorflow; using Tensorflow.Device; using Tensorflow.Eager; -using Tensorflow.UnitTest; -namespace TensorFlowNET.UnitTest +namespace Tensorflow.Native.UnitTest { - public class CApiTest : GraphModeTestBase + public class CApiTest { protected static readonly TF_Code TF_OK = TF_Code.TF_OK; protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CSession.cs b/test/TensorFlowNET.Native.UnitTest/CSession.cs similarity index 98% rename from test/TensorFlowNET.UnitTest/NativeAPI/CSession.cs rename to test/TensorFlowNET.Native.UnitTest/CSession.cs index 651bc853..c973e1b3 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/CSession.cs +++ b/test/TensorFlowNET.Native.UnitTest/CSession.cs @@ -1,10 +1,9 @@ using System; using System.Collections.Generic; using System.Linq; -using Tensorflow; using Tensorflow.Util; -namespace TensorFlowNET.UnitTest +namespace Tensorflow.Native.UnitTest { /// /// tensorflow\c\c_test_util.cc diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Context.cs similarity index 95% rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Context.cs index ea737e3a..7628bbc2 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Context.cs @@ -1,9 +1,8 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow; using Tensorflow.Device; using Tensorflow.Eager; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest.Eager { public partial class CApiEagerTest { diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Execute_MatMul_CPU.cs similarity index 97% rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Execute_MatMul_CPU.cs index 441f9d27..e8c6844a 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Execute_MatMul_CPU.cs @@ -1,10 +1,9 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; -using Tensorflow; using Tensorflow.Eager; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest.Eager { public partial class CApiEagerTest { diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs similarity index 97% rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs rename to 
test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs index 1a46a0fa..ce5a287f 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs @@ -1,8 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow; using Tensorflow.Eager; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest.Eager { public partial class CApiEagerTest { diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs similarity index 97% rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs index 372534ee..ad878115 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs @@ -1,8 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow; using Tensorflow.Eager; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest.Eager { public partial class CApiEagerTest { diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandle.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandle.cs similarity index 93% rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandle.cs rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandle.cs index 279b89cf..8f0c3b40 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandle.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandle.cs @@ -1,8 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest.Eager { public partial class CApiEagerTest { diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs similarity index 98% rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs index c16dca9b..9fc8f95e 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs @@ -1,8 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow; using Tensorflow.Eager; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest.Eager { public partial class CApiEagerTest { diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Variables.cs similarity index 97% rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Variables.cs index cb32dee0..e6a091dc 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Variables.cs @@ -1,9 +1,8 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow; using 
Tensorflow.Eager; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest.Eager { public partial class CApiEagerTest { diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.cs similarity index 99% rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.cs index 769a8eb5..a9dec9b1 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.cs @@ -1,10 +1,9 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; -using Tensorflow; using Tensorflow.Eager; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest.Eager { /// /// tensorflow\c\eager\c_api_test.cc diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs b/test/TensorFlowNET.Native.UnitTest/Eager/GradientEagerTest.cs similarity index 96% rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs rename to test/TensorFlowNET.Native.UnitTest/Eager/GradientEagerTest.cs index 66c35c1b..b4286bb7 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/Eager/GradientEagerTest.cs @@ -1,13 +1,12 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using System.Linq; -using Tensorflow; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.Gradient +namespace Tensorflow.Native.UnitTest.Eager { [TestClass] - public class GradientEagerTest : PythonTest + public class GradientEagerTest { [TestMethod] public void ConstantSquare() diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs b/test/TensorFlowNET.Native.UnitTest/GraphBuildTest.cs similarity index 94% rename from test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs rename to test/TensorFlowNET.Native.UnitTest/GraphBuildTest.cs index 3d305a89..751dc355 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/GraphBuildTest.cs @@ -1,8 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest { [TestClass] public class GraphBuildTest : CApiTest diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/GraphTest.cs b/test/TensorFlowNET.Native.UnitTest/GraphTest.cs similarity index 99% rename from test/TensorFlowNET.UnitTest/NativeAPI/GraphTest.cs rename to test/TensorFlowNET.Native.UnitTest/GraphTest.cs index 4b75becc..e40fd5c8 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/GraphTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/GraphTest.cs @@ -1,10 +1,8 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; -using Tensorflow; using static Tensorflow.Binding; -using Buffer = Tensorflow.Buffer; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace Tensorflow.Native.UnitTest { [TestClass] public class GraphTest : CApiTest diff --git a/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs b/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs new file mode 100644 index 00000000..9ae2e379 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs @@ -0,0 +1,74 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; + +namespace 
Tensorflow.Native.UnitTest.Sessions +{ + [TestClass, Ignore] + public class SessionTest : CApiTest + { + /// + /// tensorflow\c\c_api_test.cc + /// `TEST(CAPI, Session)` + /// + [TestMethod] + public void Session() + { + using var s = new Status(); + using var graph = new Graph(); + + // Make a placeholder operation. + var feed = c_test_util.Placeholder(graph, s); + + // Make a constant operation with the scalar "2". + var two = c_test_util.ScalarConst(2, graph, s); + + // Add operation. + var add = c_test_util.Add(feed, two, graph, s); + + var csession = new CSession(graph, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + // Run the graph. + var inputs = new Dictionary(); + inputs.Add(feed, new Tensor(3)); + csession.SetInputs(inputs); + + var outputs = new TF_Output[] { new TF_Output(add, 0) }; + csession.SetOutputs(outputs); + + csession.Run(s); + Tensor outTensor = csession.output_tensor(0); + EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); + EXPECT_EQ(0, outTensor.NDims); + ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); + var output_contents = outTensor.ToArray(); + EXPECT_EQ(3 + 2, output_contents[0]); + + // Add another operation to the graph. + var neg = c_test_util.Neg(add, graph, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + // Run up to the new operation. + inputs = new Dictionary(); + inputs.Add(feed, new Tensor(7)); + csession.SetInputs(inputs); + outputs = new TF_Output[] { new TF_Output(neg, 0) }; + csession.SetOutputs(outputs); + csession.Run(s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + + outTensor = csession.output_tensor(0); + ASSERT_TRUE(outTensor != IntPtr.Zero); + EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); + EXPECT_EQ(0, outTensor.NDims); // scalar + ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); + output_contents = outTensor.ToArray(); + EXPECT_EQ(-(7 + 2), output_contents[0]); + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj b/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj new file mode 100644 index 00000000..05b28723 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj @@ -0,0 +1,36 @@ + + + + netcoreapp3.1 + + false + + AnyCPU;x64 + + + + true + DEBUG;TRACE + x64 + + + + true + DEBUG;TRACE + x64 + + + + + + + + + + + + + + + + diff --git a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs new file mode 100644 index 00000000..b7a208e4 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs @@ -0,0 +1,204 @@ +using FluentAssertions; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; +using System; +using System.Linq; +using System.Runtime.InteropServices; +using static Tensorflow.Binding; + +namespace Tensorflow.Native.UnitTest.Tensors +{ + [TestClass] + public class TensorTest : CApiTest + { + [TestMethod] + public unsafe void TensorFromFixed() + { + var array = new float[1000]; + var span = new Span(array, 100, 500); + fixed (float* ptr = &MemoryMarshal.GetReference(span)) + { + using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, 4 * span.Length)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(2000, (int)t.bytesize); + } + } + + fixed (float* ptr = &array[0]) + { + using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(4000, (int)t.bytesize); 
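+ // 1000 floats * 4 bytes each = 4000 bytes for the whole array; the span case
+ // above covers 500 floats = 2000 bytes.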
+ } + } + } + + [TestMethod] + public void TensorFromArray() + { + var array = new float[1000]; + using (var t = new Tensor(array, new long[] { array.Length }, tf.float32)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(1000 * sizeof(float), (int)t.bytesize); + } + + using (var t = new Tensor(new float[] { 1 }, new long[] { 1 }, tf.float32)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(1 * sizeof(float), (int)t.bytesize); + } + + using (var t = new Tensor(new float[] { 1 }, null, tf.float32)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(1 * sizeof(float), (int)t.bytesize); + t.shape.Should().BeEmpty(); + } + } + + [TestMethod] + public void AllocateTensor() + { + ulong num_bytes = 6 * sizeof(float); + long[] dims = { 2, 3 }; + Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); + EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); + EXPECT_EQ(2, t.NDims); + EXPECT_EQ((int)dims[0], t.shape[0]); + EXPECT_EQ(num_bytes, t.bytesize); + t.Dispose(); + } + + + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, MaybeMove)` + /// + [TestMethod, Ignore] + public void MaybeMove() + { + NDArray nd = np.array(2, 3); + Tensor t = new Tensor(nd); + Tensor o = t.MaybeMove(); + ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own. + t.Dispose(); + } + + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, Tensor)` + /// + [TestMethod] + public void Tensor() + { + var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); + + var tensor = new Tensor(nd); + var array = tensor.ToArray(); + + EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); + EXPECT_EQ(tensor.rank, nd.ndim); + EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); + EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); + EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); + Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), new float[] { 1, 2, 3, 4, 5, 6 })); + } + + /// + /// Port from tensorflow\c\c_api_test.cc + /// `TEST(CAPI, SetShape)` + /// + [TestMethod] + public void SetShape() + { + var s = new Status(); + var graph = new Graph().as_default(); + + var feed = c_test_util.Placeholder(graph, s); + var feed_out_0 = new TF_Output(feed, 0); + + // Fetch the shape, it should be completely unknown. + int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); + + Assert.IsTrue(s.Code == TF_Code.TF_OK); + EXPECT_EQ(-1, num_dims); + + // Set the shape to be unknown, expect no change. + c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); + EXPECT_EQ(-1, num_dims); + + // Set the shape to be 2 x Unknown + long[] dims = { 2, -1 }; + c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); + EXPECT_EQ(2, num_dims); + + // Get the dimension vector appropriately. + var returned_dims = new long[dims.Length]; + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); + + // Set to a new valid shape: [2, 3] + dims[1] = 3; + c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + + // Fetch and see that the new value is returned. 
+ c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); + + // Try to set 'unknown' with unknown rank on the shape and see that + // it doesn't change. + c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + EXPECT_EQ(2, num_dims); + EXPECT_EQ(2, (int)returned_dims[0]); + EXPECT_EQ(3, (int)returned_dims[1]); + + // Try to set 'unknown' with same rank on the shape and see that + // it doesn't change. + dims[0] = -1; + dims[1] = -1; + c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + EXPECT_EQ(2, num_dims); + EXPECT_EQ(2, (int)returned_dims[0]); + EXPECT_EQ(3, (int)returned_dims[1]); + + // Try to fetch a shape with the wrong num_dims + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); + + // Try to set an invalid shape (cannot change 2x3 to a 2x5). + dims[1] = 5; + c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); + + // Test for a scalar. + var three = c_test_util.ScalarConst(3, graph, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + var three_out_0 = new TF_Output(three, 0); + + num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + EXPECT_EQ(0, num_dims); + c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s.Handle); + Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); + + graph.Exit(); + s.Dispose(); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs b/test/TensorFlowNET.Native.UnitTest/c_test_util.cs similarity index 98% rename from test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs rename to test/TensorFlowNET.Native.UnitTest/c_test_util.cs index fea762aa..1a43b3e1 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs +++ b/test/TensorFlowNET.Native.UnitTest/c_test_util.cs @@ -1,10 +1,8 @@ using System; using System.Diagnostics.CodeAnalysis; -using Tensorflow; using Tensorflow.Util; -using Buffer = Tensorflow.Buffer; -namespace TensorFlowNET.UnitTest +namespace Tensorflow.Native.UnitTest { /// /// Port from `tensorflow\c\c_test_util.cc` diff --git a/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs b/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs index 0bc38dd1..b95a7c6a 100644 --- a/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs @@ -8,78 +8,11 @@ using Tensorflow; using Tensorflow.Util; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace TensorFlowNET.UnitTest { - [TestClass] - public class SessionTest : CApiTest + [TestClass, Ignore] + public class SessionTest { - /// - /// tensorflow\c\c_api_test.cc - /// `TEST(CAPI, Session)` - /// - [TestMethod, Ignore] - public void Session() - { - lock (Locks.ProcessWide) - { - var s = new Status(); - var graph = new Graph().as_default(); - - // Make a placeholder operation. 
-            var feed = c_test_util.Placeholder(graph, s);
-
-            // Make a constant operation with the scalar "2".
-            var two = c_test_util.ScalarConst(2, graph, s);
-
-            // Add operation.
-            var add = c_test_util.Add(feed, two, graph, s);
-
-            var csession = new CSession(graph, s);
-            ASSERT_EQ(TF_Code.TF_OK, s.Code);
-
-            // Run the graph.
-            var inputs = new Dictionary<Operation, Tensor>();
-            inputs.Add(feed, new Tensor(3));
-            csession.SetInputs(inputs);
-
-            var outputs = new TF_Output[] { new TF_Output(add, 0) };
-            csession.SetOutputs(outputs);
-
-            csession.Run(s);
-            Tensor outTensor = csession.output_tensor(0);
-            EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
-            EXPECT_EQ(0, outTensor.NDims);
-            ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
-            var output_contents = outTensor.ToArray<int>();
-            EXPECT_EQ(3 + 2, output_contents[0]);
-
-            // Add another operation to the graph.
-            var neg = c_test_util.Neg(add, graph, s);
-            ASSERT_EQ(TF_Code.TF_OK, s.Code);
-
-            // Run up to the new operation.
-            inputs = new Dictionary<Operation, Tensor>();
-            inputs.Add(feed, new Tensor(7));
-            csession.SetInputs(inputs);
-            outputs = new TF_Output[] { new TF_Output(neg, 0) };
-            csession.SetOutputs(outputs);
-            csession.Run(s);
-            ASSERT_EQ(TF_Code.TF_OK, s.Code);
-
-            outTensor = csession.output_tensor(0);
-            ASSERT_TRUE(outTensor != IntPtr.Zero);
-            EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
-            EXPECT_EQ(0, outTensor.NDims); // scalar
-            ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
-            output_contents = outTensor.ToArray<int>();
-            EXPECT_EQ(-(7 + 2), output_contents[0]);
-
-            // Clean up
-            csession.CloseAndDelete(s);
-            ASSERT_EQ(TF_Code.TF_OK, s.Code);
-            }
-        }
-
         [TestMethod]
         public void EvalTensor()
         {
diff --git a/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs b/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs
index a81074b3..c562f4e6 100644
--- a/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs
+++ b/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs
@@ -7,201 +7,11 @@ using System.Runtime.InteropServices;
 using Tensorflow;
 using static Tensorflow.Binding;
 
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace TensorFlowNET.UnitTest
 {
-    [TestClass]
-    public class TensorTest : CApiTest
+    [TestClass, Ignore]
+    public class TensorTest
     {
-        [TestMethod]
-        public unsafe void TensorFromFixed()
-        {
-            var array = new float[1000];
-            var span = new Span<float>(array, 100, 500);
-            fixed (float* ptr = &MemoryMarshal.GetReference(span))
-            {
-                using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, 4 * span.Length))
-                {
-                    Assert.IsFalse(t.IsDisposed);
-                    Assert.AreEqual(2000, (int)t.bytesize);
-                }
-            }
-
-            fixed (float* ptr = &array[0])
-            {
-                using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length))
-                {
-                    Assert.IsFalse(t.IsDisposed);
-                    Assert.AreEqual(4000, (int)t.bytesize);
-                }
-            }
-        }
-
-        [TestMethod]
-        public unsafe void TensorFromArray()
-        {
-            var array = new float[1000];
-            using (var t = new Tensor(array, new long[] { array.Length }, tf.float32))
-            {
-                Assert.IsFalse(t.IsDisposed);
-                Assert.AreEqual(1000 * sizeof(float), (int)t.bytesize);
-            }
-
-            using (var t = new Tensor(new float[] { 1 }, new long[] { 1 }, tf.float32))
-            {
-                Assert.IsFalse(t.IsDisposed);
-                Assert.AreEqual(1 * sizeof(float), (int)t.bytesize);
-            }
-
-            using (var t = new Tensor(new float[] { 1 }, null, tf.float32))
-            {
-                Assert.IsFalse(t.IsDisposed);
-                Assert.AreEqual(1 * sizeof(float), (int)t.bytesize);
-                t.shape.Should().BeEmpty();
-            }
-        }
-
-        [TestMethod]
-        public void AllocateTensor()
-        {
-            ulong num_bytes = 6 * sizeof(float);
-            long[] dims = { 2, 3 };
-            Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes);
-            EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
-            EXPECT_EQ(2, t.NDims);
-            EXPECT_EQ((int)dims[0], t.shape[0]);
-            EXPECT_EQ(num_bytes, t.bytesize);
-            t.Dispose();
-        }
-
-
-        /// <summary>
-        /// Port from c_api_test.cc
-        /// `TEST(CAPI, MaybeMove)`
-        /// </summary>
-        [TestMethod, Ignore]
-        public void MaybeMove()
-        {
-            NDArray nd = np.array(2, 3);
-            Tensor t = new Tensor(nd);
-            Tensor o = t.MaybeMove();
-            ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own.
-            t.Dispose();
-        }
-
-        /// <summary>
-        /// Port from c_api_test.cc
-        /// `TEST(CAPI, Tensor)`
-        /// </summary>
-        [TestMethod]
-        public void Tensor()
-        {
-            var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);
-
-            var tensor = new Tensor(nd);
-            var array = tensor.ToArray<float>();
-
-            EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
-            EXPECT_EQ(tensor.rank, nd.ndim);
-            EXPECT_EQ((int)tensor.shape[0], nd.shape[0]);
-            EXPECT_EQ((int)tensor.shape[1], nd.shape[1]);
-            EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float));
-            Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] { 1, 2, 3, 4, 5, 6 }));
-        }
-
-        /// <summary>
-        /// Port from tensorflow\c\c_api_test.cc
-        /// `TEST(CAPI, SetShape)`
-        /// </summary>
-        [TestMethod]
-        public void SetShape()
-        {
-            var s = new Status();
-            var graph = new Graph().as_default();
-
-            var feed = c_test_util.Placeholder(graph, s);
-            var feed_out_0 = new TF_Output(feed, 0);
-
-            // Fetch the shape, it should be completely unknown.
-            int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
-
-            Assert.IsTrue(s.Code == TF_Code.TF_OK);
-            EXPECT_EQ(-1, num_dims);
-
-            // Set the shape to be unknown, expect no change.
-            c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle);
-            Assert.IsTrue(s.Code == TF_Code.TF_OK);
-            num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
-            EXPECT_EQ(-1, num_dims);
-
-            // Set the shape to be 2 x Unknown
-            long[] dims = { 2, -1 };
-            c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle);
-            Assert.IsTrue(s.Code == TF_Code.TF_OK);
-            num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
-            EXPECT_EQ(2, num_dims);
-
-            // Get the dimension vector appropriately.
-            var returned_dims = new long[dims.Length];
-            c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
-            Assert.IsTrue(s.Code == TF_Code.TF_OK);
-            Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));
-
-            // Set to a new valid shape: [2, 3]
-            dims[1] = 3;
-            c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle);
-            Assert.IsTrue(s.Code == TF_Code.TF_OK);
-
-            // Fetch and see that the new value is returned.
-            c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
-            Assert.IsTrue(s.Code == TF_Code.TF_OK);
-            Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));
-
-            // Try to set 'unknown' with unknown rank on the shape and see that
-            // it doesn't change.
-            c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle);
-            Assert.IsTrue(s.Code == TF_Code.TF_OK);
-            c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
-            Assert.IsTrue(s.Code == TF_Code.TF_OK);
-            EXPECT_EQ(2, num_dims);
-            EXPECT_EQ(2, (int)returned_dims[0]);
-            EXPECT_EQ(3, (int)returned_dims[1]);
-
-            // Try to set 'unknown' with same rank on the shape and see that
-            // it doesn't change.
-            dims[0] = -1;
-            dims[1] = -1;
-            c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle);
-            Assert.IsTrue(s.Code == TF_Code.TF_OK);
-            c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
-            Assert.IsTrue(s.Code == TF_Code.TF_OK);
-            EXPECT_EQ(2, num_dims);
-            EXPECT_EQ(2, (int)returned_dims[0]);
-            EXPECT_EQ(3, (int)returned_dims[1]);
-
-            // Try to fetch a shape with the wrong num_dims
-            c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle);
-            Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);
-
-            // Try to set an invalid shape (cannot change 2x3 to a 2x5).
-            dims[1] = 5;
-            c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle);
-            Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);
-
-            // Test for a scalar.
-            var three = c_test_util.ScalarConst(3, graph, s);
-            Assert.IsTrue(s.Code == TF_Code.TF_OK);
-            var three_out_0 = new TF_Output(three, 0);
-
-            num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle);
-            Assert.IsTrue(s.Code == TF_Code.TF_OK);
-            EXPECT_EQ(0, num_dims);
-            c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s.Handle);
-            //Assert.IsTrue(s.Code == TF_Code.TF_OK);
-
-            // graph.Dispose();
-            s.Dispose();
-        }
-
         [TestMethod]
         public void sparse_to_dense()
         {
@@ -271,32 +81,5 @@ namespace TensorFlowNET.UnitTest.NativeAPI
                 Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>()));
             }
         }
-
-        /// <summary>
-        /// Creates a tensor from an image of 256x256x3 and resizes it to 100x100x3
-        /// </summary>
-        [TestMethod]
-        public unsafe void tensor_resize()
-        {
-            tf.enable_eager_execution();
-
-            var imageArray = new float[256 * 256 * 3];
-
-            using var newSize = tf.convert_to_tensor(new int[] { 100, 100 });
-
-            using (var t = tf.constant(imageArray, tf.float32, (1, 256, 256, 3)))
-            {
-                Assert.IsFalse(t.IsDisposed);
-                Assert.AreEqual(256 * 256 * 3 * sizeof(float), (int)t.bytesize);
-
-                using var resized = tf.image.resize_bilinear(t, newSize);
-                EXPECT_EQ(resized.shape[0], 1);
-                EXPECT_EQ(resized.shape[1], 100);
-                EXPECT_EQ(resized.shape[2], 100);
-                EXPECT_EQ(resized.shape[3], 3);
-            }
-
-            tf.compat.v1.disable_eager_execution();
-        }
     }
 }
\ No newline at end of file
diff --git a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
index a83800e6..bd25736c 100644
--- a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
+++ b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
@@ -1,4 +1,5 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
 using static Tensorflow.Binding;
 
 namespace TensorFlowNET.UnitTest
@@ -10,11 +11,40 @@ namespace TensorFlowNET.UnitTest
         {
             if (!tf.executing_eagerly())
                 tf.enable_eager_execution();
+            tf.Context.ensure_initialized();
         }
 
         [TestCleanup]
         public void TestClean()
         {
         }
+
+        public bool Equal(float[] f1, float[] f2)
+        {
+            bool ret = false;
+            var tolerance = .000001f;
+            for (var i = 0; i < f1.Length; i++)
+            {
+                ret = Math.Abs(f1[i] - f2[i]) <= tolerance;
+                if (!ret)
+                    break;
+            }
+
+            return ret;
+        }
+
+        public bool Equal(double[] d1, double[] d2)
+        {
+            bool ret = false;
+            var tolerance = .000000000000001f;
+            for (var i = 0; i < d1.Length; i++)
+            {
+                ret = Math.Abs(d1[i] - d2[i]) <= tolerance;
+                if (!ret)
+                    break;
+            }
+
+            return ret;
+        }
     }
 }
diff --git a/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs b/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs
index 0ac5f33e..c8634542 100644
--- a/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs
+++ 
b/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs @@ -6,10 +6,10 @@ using System.Threading; using Tensorflow; using static Tensorflow.Binding; -namespace TensorFlowNET.UnitTest.NativeAPI +namespace TensorFlowNET.UnitTest { [TestClass] - public class EnforcedSinglethreadingTests : CApiTest + public class EnforcedSinglethreadingTests { private static readonly object _singlethreadLocker = new object(); diff --git a/test/TensorFlowNET.UnitTest/GraphModeTestBase.cs b/test/TensorFlowNET.UnitTest/GraphModeTestBase.cs index 67f4e3a7..8d008ddb 100644 --- a/test/TensorFlowNET.UnitTest/GraphModeTestBase.cs +++ b/test/TensorFlowNET.UnitTest/GraphModeTestBase.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using TensorFlowNET.UnitTest; using static Tensorflow.Binding; +using static Tensorflow.KerasApi; namespace Tensorflow.UnitTest { @@ -15,6 +16,7 @@ namespace Tensorflow.UnitTest [TestCleanup] public void TestClean() { + keras.backend.clear_session(); tf.enable_eager_execution(); } } diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs index bac400d8..6f816d8f 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs @@ -5,7 +5,7 @@ using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.nn_test { [TestClass] - public class ActivationFunctionTest : TFNetApiTest + public class ActivationFunctionTest : EagerModeTestBase { // A constant vector of size 6 Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs index 4ec4eb25..e57e5072 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs @@ -6,7 +6,7 @@ using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.ManagedAPI { [TestClass] - public class BitwiseApiTest : TFNetApiTest + public class BitwiseApiTest : EagerModeTestBase { [TestInitialize] public void Init() diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs index e821d9e7..df00d588 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs @@ -7,7 +7,7 @@ using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.ManagedAPI { [TestClass] - public class FunctionApiTest : TFNetApiTest + public class FunctionApiTest : EagerModeTestBase { Tensor Min(Tensor a, Tensor b) { diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs index 874ab053..26e89404 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs @@ -6,7 +6,7 @@ using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.ManagedAPI { [TestClass] - public class MathApiTest : TFNetApiTest + public class MathApiTest : EagerModeTestBase { // A constant vector of size 6 Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/TFNetApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/TFNetApiTest.cs deleted file mode 100644 index 5fd20899..00000000 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/TFNetApiTest.cs +++ 
/dev/null @@ -1,35 +0,0 @@ -using System; - -namespace TensorFlowNET.UnitTest -{ - public class TFNetApiTest - { - public bool Equal(float[] f1, float[] f2) - { - bool ret = false; - var tolerance = .000001f; - for (var i = 0; i < f1.Length; i++) - { - ret = Math.Abs(f1[i] - f2[i]) <= tolerance; - if (!ret) - break; - } - - return ret; - } - - public bool Equal(double[] d1, double[] d2) - { - bool ret = false; - var tolerance = .000000000000001f; - for (var i = 0; i < d1.Length; i++) - { - ret = Math.Abs(d1[i] - d2[i]) <= tolerance; - if (!ret) - break; - } - - return ret; - } - } -} diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ZeroFractionTest.cs deleted file mode 100644 index 349a6c39..00000000 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/ZeroFractionTest.cs +++ /dev/null @@ -1,86 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using NumSharp; -using System; -using System.Linq; -using Tensorflow; -using Tensorflow.UnitTest; - -namespace TensorFlowNET.UnitTest.nn_test -{ - [TestClass] - public class ZeroFractionTest : GraphModeTestBase - { - protected double _ZeroFraction(NDArray x) - { - assert(x.shape); - int total_elements = np.prod(x.shape); - - var eps = 1e-8; - var nonzeros = x.Data().Count(d => Math.Abs(d) > eps); - return 1.0 - nonzeros / (double)total_elements; - } - - [Ignore("TODO implement nn_impl.zero_fraction")] - [TestMethod] - public void testZeroFraction() - { - var x_shape = new Shape(5, 17); - var x_np = np.random.randint(0, 2, x_shape); - //x_np.astype(np.float32); - var y_np = this._ZeroFraction(x_np); - - var x_tf = constant_op.constant(x_np); - x_tf.set_shape(x_shape); - var y_tf = nn_impl.zero_fraction(x_tf); - var y_tf_np = self.evaluate(y_tf); - - var eps = 1e-8; - self.assertAllClose(y_tf_np, y_np, eps); - } - - [Ignore("TODO implement nn_impl.zero_fraction")] - [TestMethod] - public void testZeroFractionEmpty() - { - - var x = np.zeros(0); - var y = self.evaluate(nn_impl.zero_fraction(new Tensor(x))); - self.assertTrue(np.isnan(y)); - } - - [Ignore("TODO implement nn_impl.zero_fraction")] - [TestMethod] - public void testZeroFraction2_27Zeros() - { - var sparsity = nn_impl.zero_fraction( - array_ops.zeros(new Shape((int)Math.Pow(2, 27 * 1.01)), dtypes.int8)); - self.assertAllClose(1.0, self.evaluate(sparsity)); - } - - [Ignore("TODO implement nn_impl.zero_fraction")] - [TestMethod] - public void testZeroFraction2_27Ones() - { - var sparsity = nn_impl.zero_fraction( - array_ops.ones(new TensorShape((int)Math.Pow(2, 27 * 1.01)), dtypes.int8)); - self.assertAllClose(0.0, self.evaluate(sparsity)); - } - - [Ignore("TODO implement nn_impl.zero_fraction")] - [TestMethod] - public void testUnknownSize() - { - var value = array_ops.placeholder(dtype: dtypes.float32); - var sparsity = nn_impl.zero_fraction(value); - using (var sess = self.cached_session()) - { - // TODO: make this compile - //self.assertAllClose( - // 0.25, - // sess.run(sparsity, {value: [[0., 1.], [0.3, 2.]]})); - } - } - - - } -} diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs index 91580d31..10f32ca2 100644 --- a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs +++ b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs @@ -73,6 +73,7 @@ namespace TensorFlowNET.UnitTest tf.peak_default_graph().Should().BeNull(); var beforehand = tf.get_default_graph(); //this should create default automatically. 
beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread."); + beforehand.as_default(); tf.peak_default_graph().Should().NotBeNull(); using (var sess = tf.Session()) diff --git a/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs b/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs index ed379f3c..d1cda728 100644 --- a/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs +++ b/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs @@ -1,4 +1,5 @@ -using System.IO; +using System; +using System.IO; namespace TensorFlowNET.UnitTest { @@ -6,8 +7,16 @@ namespace TensorFlowNET.UnitTest { public static string GetFullPathFromDataDir(string fileName) { - var dir = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "..", "..", "data"); - return Path.GetFullPath(Path.Combine(dir, fileName)); + var dataDir = GetRootContentDir(Directory.GetCurrentDirectory()); + return Path.Combine(dataDir, fileName); + } + + static string GetRootContentDir(string dir) + { + var path = Path.GetFullPath(Path.Combine(dir, "data")); + if (Directory.Exists(path)) + return path; + return GetRootContentDir(Path.GetFullPath(Path.Combine(dir, ".."))); } } } diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/control_flow_ops_test.py b/test/TensorFlowNET.UnitTest/control_flow_ops_test/control_flow_ops_test.py deleted file mode 100644 index f1dd4f52..00000000 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/control_flow_ops_test.py +++ /dev/null @@ -1,1059 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for control_flow_ops.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import numpy as np - -from tensorflow.python import tf2 -from tensorflow.core.framework import graph_pb2 -from tensorflow.core.framework import node_def_pb2 -from tensorflow.python.eager import def_function -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import embedding_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import -from tensorflow.python.platform import googletest -from tensorflow.python.training import momentum -from tensorflow.python.util import nest - - -TestTuple = collections.namedtuple("TestTuple", "a b") -SingletonTestTuple = collections.namedtuple("SingletonTestTuple", "a") - - -class GroupTestCase(test_util.TensorFlowTestCase): - - def _StripNode(self, nd): - snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input) - if nd.device: - snode.device = nd.device - return snode - - def _StripGraph(self, gd): - """Copy gd keeping only, node.name, node.op, node.input, and node.device.""" - return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node]) - - def testGroup_NoDevices(self): - with ops.Graph().as_default() as g: - a = constant_op.constant(0, name="a") - b = constant_op.constant(0, name="b") - c = constant_op.constant(0, name="c") - control_flow_ops.group(a.op, b.op, c.op, name="root") - gd = g.as_graph_def() - self.assertProtoEquals(""" - node { name: "a" op: "Const"} - node { name: "b" op: "Const"} - node { name: "c" op: "Const"} - node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" } - """, self._StripGraph(gd)) - - def testGroup_OneDevice(self): - with ops.Graph().as_default() as g: - with g.device("/task:0"): - a = constant_op.constant(0, name="a") - b = constant_op.constant(0, name="b") - control_flow_ops.group(a.op, b.op, name="root") - gd = g.as_graph_def() - self.assertProtoEquals(""" - node { name: "a" op: "Const" device: "/task:0" } - node { name: "b" op: "Const" device: "/task:0" } - node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" } - """, self._StripGraph(gd)) - - def testGroup_MultiDevice(self): - with ops.Graph().as_default() as g: - with g.device("/task:0"): - a = constant_op.constant(0, name="a") - b = constant_op.constant(0, name="b") - with g.device("/task:1"): - c = constant_op.constant(0, name="c") - d = constant_op.constant(0, name="d") - with g.device("/task:2"): - control_flow_ops.group(a.op, b.op, c.op, d.op, name="root") - gd = g.as_graph_def() - self.assertProtoEquals(""" - node { name: "a" op: "Const" device: "/task:0"} - node { name: "b" op: "Const" 
device: "/task:0"} - node { name: "c" op: "Const" device: "/task:1"} - node { name: "d" op: "Const" device: "/task:1"} - node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b" - device: "/task:0" } - node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d" - device: "/task:1" } - node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1" - device: "/task:2" } - """, self._StripGraph(gd)) - - def testPassingList(self): - with ops.Graph().as_default() as g: - a = constant_op.constant(0, name="a") - b = constant_op.constant(0, name="b") - control_flow_ops.group([a.op, b.op], name="root") - gd = g.as_graph_def() - self.assertProtoEquals(""" - node { name: "a" op: "Const"} - node { name: "b" op: "Const"} - node { name: "root" op: "NoOp" input: "^a" input: "^b" } - """, self._StripGraph(gd)) - - @test_util.run_deprecated_v1 - def testPassingNonTensors(self): - with self.assertRaises(TypeError): - control_flow_ops.group(1, 2) - - -class ShapeTestCase(test_util.TensorFlowTestCase): - - def testShape(self): - tensor = constant_op.constant([1.0, 2.0]) - self.assertEquals([2], tensor.get_shape()) - self.assertEquals([2], - control_flow_ops.with_dependencies( - [constant_op.constant(1.0)], tensor).get_shape()) - - -class WithDependenciesTestCase(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testTupleDependencies(self): - counter = variable_scope.get_variable( - "my_counter", shape=[], initializer=init_ops.zeros_initializer()) - increment_counter = state_ops.assign_add(counter, 1) - const_with_dep = control_flow_ops.with_dependencies( - (increment_counter, constant_op.constant(42)), - constant_op.constant(7)) - - self.evaluate(variables.global_variables_initializer()) - self.assertEquals(0, self.evaluate(counter)) - self.assertEquals(7, self.evaluate(const_with_dep)) - self.assertEquals(1, self.evaluate(counter)) - - @test_util.run_deprecated_v1 - def testListDependencies(self): - counter = variable_scope.get_variable( - "my_counter", shape=[], initializer=init_ops.zeros_initializer()) - increment_counter = state_ops.assign_add(counter, 1) - const_with_dep = control_flow_ops.with_dependencies( - [increment_counter, constant_op.constant(42)], - constant_op.constant(7)) - - self.evaluate(variables.global_variables_initializer()) - self.assertEquals(0, self.evaluate(counter)) - self.assertEquals(7, self.evaluate(const_with_dep)) - self.assertEquals(1, self.evaluate(counter)) - - -class SwitchTestCase(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testIndexedSlicesWithDenseShape(self): - with self.cached_session(): - data = ops.IndexedSlices( - constant_op.constant([1, 2, 3]), - constant_op.constant([0, 1]), - dense_shape=constant_op.constant([3])) - zero = constant_op.constant(0) - one = constant_op.constant(1) - less_op = math_ops.less(zero, one) - _, switch_true = control_flow_ops.switch(data, less_op) - self.assertAllEqual([1, 2, 3], switch_true.values.eval()) - self.assertAllEqual([0, 1], switch_true.indices.eval()) - - @test_util.run_deprecated_v1 - def testIndexedSlicesGradient(self): - embedding_matrix = variable_scope.get_variable( - "embedding_matrix", [5, 5], - initializer=init_ops.random_normal_initializer()) - - def cond(it, _): - return it < 5 - - def body(it, cost): - embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0]) - cost += math_ops.reduce_sum(embedding) - return it + 1, cost - - _, cost = control_flow_ops.while_loop( - cond, body, [constant_op.constant(0), - constant_op.constant(0.0)]) - 
optimizer = momentum.MomentumOptimizer(0.1, 0.9) - train_op = optimizer.minimize(cost) - with self.cached_session(): - self.evaluate(variables.global_variables_initializer()) - for _ in range(10): - self.evaluate([train_op]) - - def testResourceReadInLoop(self): - embedding_matrix = variable_scope.get_variable( - "embedding_matrix", initializer=[[2.0], [3.0]], use_resource=True) - - def cond(it, _): - return it < 5 - - def body(it, cost): - embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) - cost += math_ops.reduce_sum(embedding) - return it + 1, cost - - _, cost = control_flow_ops.while_loop( - cond, body, [constant_op.constant(0), - constant_op.constant(0.0)]) - with self.cached_session(): - self.evaluate(variables.global_variables_initializer()) - self.assertAllEqual(10.0, self.evaluate(cost)) - - def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False): - embedding_matrix = variable_scope.get_variable( - "embedding_matrix", [5, 5], - initializer=init_ops.random_normal_initializer(), - use_resource=use_resource) - - def cond(it, _): - return it < 5 - - def body(it, cost): - embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) - cost = control_flow_ops.cond( - math_ops.equal(it, 3), lambda: math_ops.square(cost), - (lambda: cost + math_ops.reduce_sum(embedding))) - return it + 1, cost - - _, cost = control_flow_ops.while_loop( - cond, body, [constant_op.constant(0), - constant_op.constant(0.0)]) - - dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0] - dynamic_grads = math_ops.segment_sum(dynamic_grads.values, - dynamic_grads.indices) - - embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) - static = math_ops.square( - math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) + - math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding) - static_grads = gradients_impl.gradients(static, [embedding_matrix])[0] - static_grads = math_ops.segment_sum(static_grads.values, - static_grads.indices) - - with self.cached_session(): - self.evaluate(variables.global_variables_initializer()) - self.assertAllEqual(*self.evaluate([static_grads, dynamic_grads])) - - def testIndexedSlicesGradientInCondInWhileLoop(self): - self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False) - - def testIndexedSlicesGradientInCondInWhileLoopResource(self): - self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=True) - - @test_util.run_v1_only("b/120545219") - def testIndexedSlicesWithShapeGradientInWhileLoop(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.cached_session() as sess: - num_steps = 9 - - inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps]) - initial_outputs = tensor_array_ops.TensorArray( - dtype=dtype, size=num_steps) - initial_i = constant_op.constant(0, dtype=dtypes.int32) - - def cond(i, _): - return i < num_steps # pylint: disable=cell-var-from-loop - - def body(i, outputs): - x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop - outputs = outputs.write(i, x) - return i + 1, outputs - - _, outputs = control_flow_ops.while_loop(cond, body, - [initial_i, initial_outputs]) - - outputs = math_ops.reduce_sum(outputs.stack()) - r = gradients_impl.gradients([outputs], [inputs])[0] - grad_wr_inputs = ops.convert_to_tensor(r) - o, grad = sess.run([outputs, grad_wr_inputs], - feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]}) - self.assertEquals(o, 20) - self.assertAllEqual(grad, [1] * num_steps) - - @test_util.run_v1_only("b/120545219") - def 
testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.cached_session() as sess: - inputs = array_ops.placeholder(dtype=dtype) - initial_outputs = tensor_array_ops.TensorArray( - dtype=dtype, dynamic_size=True, size=1) - initial_i = constant_op.constant(0, dtype=dtypes.int32) - - def cond(i, _): - return i < array_ops.size(inputs) # pylint: disable=cell-var-from-loop - - def body(i, outputs): - x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop - outputs = outputs.write(i, x) - return i + 1, outputs - - _, outputs = control_flow_ops.while_loop(cond, body, - [initial_i, initial_outputs]) - - outputs = math_ops.reduce_sum(outputs.stack()) - r = gradients_impl.gradients([outputs], [inputs])[0] - grad_wr_inputs = ops.convert_to_tensor(r) - o, grad = sess.run([outputs, grad_wr_inputs], - feed_dict={inputs: [1, 3, 2]}) - self.assertEquals(o, 6) - self.assertAllEqual(grad, [1] * 3) - - @test_util.run_deprecated_v1 - def testGradientThroughSingleBranchOutsideOfContext(self): - x = constant_op.constant(2.) - s = constant_op.constant(True) - x_false, x_true = control_flow_ops.switch(x, s) - grad_x_true = gradients_impl.gradients(x_true, x)[0] - grad_x_false = gradients_impl.gradients(x_false, x)[0] - self.assertEquals(self.evaluate(grad_x_true), 1.) - self.assertEquals(self.evaluate(grad_x_false), 0.) - - -class CondTest(test_util.TensorFlowTestCase): - - def testCondTrue(self): - x = constant_op.constant(2) - y = constant_op.constant(5) - z = control_flow_ops.cond( - math_ops.less( - x, - y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23)) - self.assertEquals(self.evaluate(z), 34) - - def testCondFalse(self): - x = constant_op.constant(2) - y = constant_op.constant(1) - z = control_flow_ops.cond( - math_ops.less( - x, - y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23)) - self.assertEquals(self.evaluate(z), 24) - - def testCondTrueLegacy(self): - x = constant_op.constant(2) - y = constant_op.constant(5) - z = control_flow_ops.cond( - math_ops.less(x, y), - fn1=lambda: math_ops.multiply(x, 17), - fn2=lambda: math_ops.add(y, 23)) - self.assertEquals(self.evaluate(z), 34) - - def testCondFalseLegacy(self): - x = constant_op.constant(2) - y = constant_op.constant(1) - z = control_flow_ops.cond( - math_ops.less(x, y), - fn1=lambda: math_ops.multiply(x, 17), - fn2=lambda: math_ops.add(y, 23)) - self.assertEquals(self.evaluate(z), 24) - - @test_util.run_deprecated_v1 - def testCondModifyBoolPred(self): - # This test in particular used to fail only when running in GPU, hence - # use_gpu=True. 
- with test_util.use_gpu(): - bool_var = variable_scope.get_variable( - "bool_var", dtype=dtypes.bool, initializer=True) - cond_on_bool_var = control_flow_ops.cond( - pred=bool_var, - true_fn=lambda: state_ops.assign(bool_var, False), - false_fn=lambda: True) - self.evaluate(bool_var.initializer) - self.assertEquals(self.evaluate(cond_on_bool_var), False) - self.assertEquals(self.evaluate(cond_on_bool_var), True) - - def testCondMissingArg1(self): - x = constant_op.constant(1) - with self.assertRaises(TypeError): - control_flow_ops.cond(True, false_fn=lambda: x) - - def testCondMissingArg2(self): - x = constant_op.constant(1) - with self.assertRaises(TypeError): - control_flow_ops.cond(True, lambda: x) - - def testCondDuplicateArg1(self): - x = constant_op.constant(1) - with self.assertRaises(TypeError): - control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x) - - def testCondDuplicateArg2(self): - x = constant_op.constant(1) - with self.assertRaises(TypeError): - control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x) - - -class ContextTest(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testCondContext(self): - with self.cached_session() as sess: - x = constant_op.constant(2) - y = constant_op.constant(5) - control_flow_ops.cond( - math_ops.less(x, y), lambda: math_ops.multiply(x, 17), - lambda: math_ops.add(y, 23)) - for op in sess.graph.get_operations(): - c = op._get_control_flow_context() - if c: - self.assertProtoEquals( - c.to_proto(), - control_flow_ops.CondContext.from_proto(c.to_proto()).to_proto()) - - def _testWhileContextHelper(self, maximum_iterations=None): - with self.cached_session() as sess: - i = constant_op.constant(0) - c = lambda i: math_ops.less(i, 10) - b = lambda i: math_ops.add(i, 1) - control_flow_ops.while_loop( - c, b, [i], maximum_iterations=maximum_iterations) - for op in sess.graph.get_operations(): - control_flow_context = op._get_control_flow_context() - if control_flow_context: - self.assertProtoEquals( - control_flow_context.to_proto(), - control_flow_ops.WhileContext.from_proto( - control_flow_context.to_proto()).to_proto()) - - @test_util.run_deprecated_v1 - def testWhileContext(self): - self._testWhileContextHelper() - - @test_util.run_deprecated_v1 - def testWhileContextWithMaximumIterations(self): - self._testWhileContextHelper(maximum_iterations=10) - - @test_util.run_deprecated_v1 - def testControlContextImportScope(self): - class NoABCControlFlowContext(control_flow_ops.ControlFlowContext): - """A noop wrapper around `ControlFlowContext`. - - `ControlFlowContext` is an ABC and therefore cannot be instantiated. - """ - # pylint: disable=useless-super-delegation - - def to_control_flow_context_def(self, context_def, export_scope=None): - super(NoABCControlFlowContext, self).to_control_flow_context_def( - context_def, export_scope) - - with self.cached_session(): - constant_op.constant(0, name="a") - constant_op.constant(2, name="test_scope/a") - b1 = constant_op.constant(1, name="b") - b2 = constant_op.constant(3, name="test_scope/b") - - c = NoABCControlFlowContext() - c._values = ["a", "b"] - c._external_values = {"a": b1} - - c_with_scope = NoABCControlFlowContext( - values_def=c._to_values_def(), import_scope="test_scope") - - # _values and _external_values should be have scope prepended. 
- self.assertEquals( - c_with_scope._values, set(["test_scope/a", "test_scope/b"])) - self.assertEquals( - c_with_scope._external_values, {"test_scope/a": b2}) - - # Calling _to_proto() with export_scope should remove "test_scope". - self.assertProtoEquals( - c._to_values_def(), - c_with_scope._to_values_def(export_scope="test_scope")) - - -def _get_nested_shape(nested): - - def _get_shape(tensor): - if isinstance(tensor, tensor_array_ops.TensorArray): - return tensor_array_ops.TensorArray - elif isinstance(tensor, ops.IndexedSlices): - return tensor.dense_shape - else: - return tensor.get_shape() - - return nest.map_structure(_get_shape, nested) - - -def _create_tensor_array(size, shape): - ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size, - clear_after_read=False) - for i in range(size): - ta = ta.write(i, array_ops.zeros(shape)) - return ta - - -def _raw_nested_shape(nested_shape): - - def _raw_shape(shape): - if isinstance(shape, tensor_shape.TensorShape) and shape.ndims is not None: - return [x.value for x in shape.dims] - else: - return None - - return nest.map_structure(_raw_shape, nested_shape) - - -# TODO(yori): Add tests for indexed slices. -class DataTypesTest(test_util.TensorFlowTestCase): - - def assertAllEqualNested(self, a, b): - if isinstance(a, (list, tuple)): - for entry_a, entry_b in zip(a, b): - self.assertAllEqualNested(entry_a, entry_b) - else: - self.assertAllEqual(a, b) - - def _testShape(self, fn_true, fn_false, expected_shape, - strict=False): - condition = array_ops.placeholder(dtypes.bool) - output_cond = control_flow_ops.cond(condition, fn_true, fn_false, - strict=strict) - self.assertEqual( - _raw_nested_shape(_get_nested_shape(output_cond)), - _raw_nested_shape(expected_shape)) - - output_case = control_flow_ops.case([(condition, fn_true)], fn_false, - strict=strict) - self.assertEqual( - _raw_nested_shape(_get_nested_shape(output_case)), - _raw_nested_shape(expected_shape)) - - def _testReturnValues(self, fn_true, fn_false, expected_value_true, - expected_value_false, strict=False, - check_cond=True, feed_dict=None): - if feed_dict is None: feed_dict = {} - - condition = array_ops.placeholder(dtypes.bool) - output_cond = control_flow_ops.cond(condition, fn_true, fn_false, - strict=strict) - output_case = control_flow_ops.case([(condition, fn_true)], fn_false, - strict=strict) - - with self.cached_session() as sess: - self.evaluate(variables.global_variables_initializer()) - true_feed_dict = {condition: True} - true_feed_dict.update(feed_dict) - result_cond, result_case = sess.run([output_cond, output_case], - feed_dict=true_feed_dict) - self.assertAllEqualNested(result_cond, expected_value_true) - if check_cond: - self.assertAllEqualNested(result_case, expected_value_true) - false_feed_dict = {condition: False} - false_feed_dict.update(feed_dict) - result_cond, result_case = sess.run([output_cond, output_case], - feed_dict=false_feed_dict) - self.assertAllEqualNested(result_cond, expected_value_false) - if check_cond: - self.assertAllEqualNested(result_case, expected_value_false) - - @test_util.run_deprecated_v1 - def test_int(self): - shape = tensor_shape.TensorShape([]) - fn_true = lambda: 1 - fn_false = lambda: 2 - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, 1, 2) - self._testShape(fn_true, fn_false, shape, strict=True) - self._testReturnValues(fn_true, fn_false, 1, 2, strict=True) - - @test_util.run_deprecated_v1 - def test_float(self): - shape = tensor_shape.TensorShape([]) - fn_true = lambda: 
1.0 - fn_false = lambda: 2.0 - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, 1.0, 2.0) - - @test_util.run_deprecated_v1 - def test_noop(self): - shape = tensor_shape.TensorShape(None) - self._testShape(control_flow_ops.no_op, control_flow_ops.no_op, shape) - self._testReturnValues(control_flow_ops.no_op, control_flow_ops.no_op, - True, False, check_cond=False) - - @test_util.run_deprecated_v1 - def test_string(self): - shape = tensor_shape.TensorShape([]) - fn_true = lambda: "abc" - fn_false = lambda: "xyz" - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, b"abc", b"xyz") - - @test_util.run_deprecated_v1 - def test_variable(self): - shape = tensor_shape.TensorShape([]) - fn_true = lambda: variables.Variable(3.0) - fn_false = lambda: variables.Variable(4.0) - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, 3.0, 4.0) - - @test_util.run_v1_only("b/120553181") - def test_none(self): - fn_none = lambda: None - fn_tensor = lambda: constant_op.constant(1) - - with self.assertRaises(ValueError): - control_flow_ops.cond(constant_op.constant(True), fn_none, fn_tensor) - - with self.assertRaises(ValueError): - control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_none) - - @test_util.run_deprecated_v1 - def test_tensors(self): - - def _build_true_branch(dtype): - - def _build(): - return (array_ops.zeros([2, 2], dtype=dtype), - array_ops.ones([3, 3], dtype=dtype)) - - return _build - - def _build_false_branch(dtype): - - def _build(): - return (array_ops.ones([2, 2], dtype=dtype), - array_ops.zeros([3, 3], dtype=dtype)) - - return _build - - for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8): - shape = (tensor_shape.TensorShape([2, 2]), - tensor_shape.TensorShape([3, 3])) - fn_true = _build_true_branch(dtype) - fn_false = _build_false_branch(dtype) - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, - (np.zeros([2, 2]), np.ones([3, 3])), - (np.ones([2, 2]), np.zeros([3, 3]))) - - @test_util.run_deprecated_v1 - def test_tensors_unknown_shape(self): - - def _build_true_branch(dtype): - tensor = array_ops.placeholder(dtype=dtype, shape=None) - - def _build(): - return tensor - - return _build, tensor - - def _build_false_branch(dtype): - tensor = array_ops.placeholder(dtype=dtype, shape=None) - - def _build(): - return tensor - - return _build, tensor - - for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8): - shape = tensor_shape.TensorShape(None) - fn_true, true_tensor = _build_true_branch(dtype) - fn_false, false_tensor = _build_false_branch(dtype) - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, - np.zeros([2, 2]), np.ones([2, 2]), - feed_dict={true_tensor: np.zeros([2, 2]), - false_tensor: np.ones([2, 2])}) - - @test_util.run_deprecated_v1 - def test_sparse_tensors(self): - shape = tensor_shape.TensorShape([None, None]) - - def true_fn(): - return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]], - values=[1, 2], dense_shape=[3, 4])] - - def false_fn(): - return [sparse_tensor.SparseTensor(indices=[[0, 0], [2, 1]], - values=[3, 4], dense_shape=[3, 4])] - - value1 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 2]], - values=[1, 2], dense_shape=[3, 4]) - value2 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [2, 1]], - values=[3, 4], dense_shape=[3, 4]) - # Non-strict cond is only available in v1 - if not tf2.enabled(): - self._testShape(true_fn, 
false_fn, shape) - self._testReturnValues(true_fn, false_fn, value1, value2) - self._testShape(true_fn, false_fn, [shape], strict=True) - self._testReturnValues(true_fn, false_fn, [value1], [value2], strict=True) - - @test_util.run_deprecated_v1 - def test_tensors_with_partially_specified_shapes(self): - - def _build_branch(dtype, shape): - a = array_ops.placeholder(dtype=dtype, shape=shape[0]) - b = array_ops.placeholder(dtype=dtype, shape=shape[1]) - c = array_ops.placeholder(dtype=dtype, shape=shape[2]) - - def _build(): - return a, b, c - - return _build, (a, b, c) - - for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8): - shape = (tensor_shape.TensorShape([None, 2]), - tensor_shape.TensorShape([None]), - tensor_shape.TensorShape([3, None])) - fn_true, true_tensors = _build_branch(dtype, shape) - fn_false, false_tensors = _build_branch(dtype, shape) - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, - (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])), - (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])), - feed_dict={true_tensors[0]: np.zeros([2, 2]), - false_tensors[0]: np.zeros([2, 2]), - true_tensors[1]: np.zeros([5]), - false_tensors[1]: np.zeros([5]), - true_tensors[2]: np.ones([3, 3]), - false_tensors[2]: np.ones([3, 3])}) - - @test_util.run_deprecated_v1 - def test_tensor_arrays(self): - element_shape = tensor_shape.TensorShape([2]) - ta1 = _create_tensor_array(4, element_shape) - ta2 = _create_tensor_array(4, element_shape) - shape = tensor_array_ops.TensorArray - fn_true = lambda: ta1 - fn_false = lambda: ta2 - self._testShape(fn_true, fn_false, shape) - - @test_util.run_deprecated_v1 - def test_tensor_array_reads(self): - shape = tensor_shape.TensorShape([2]) - ta = _create_tensor_array(4, shape) - fn_true = lambda: ta.read(0) - fn_false = lambda: ta.read(1) - self._testShape(fn_true, fn_false, shape) - - @test_util.run_deprecated_v1 - def test_list(self): - shape = [tensor_shape.TensorShape([]), tensor_shape.TensorShape([]), - tensor_shape.TensorShape([])] - fn_true = lambda: [constant_op.constant(1), 2, variables.Variable(3.0)] - fn_false = lambda: [constant_op.constant(3), 4, variables.Variable(5.0)] - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, [1, 2, 3.0], [3, 4, 5.0]) - - @test_util.run_v1_only("Non-strict cond is only available in v1") - def test_non_strict(self): - shape = tensor_shape.TensorShape([]) - fn_tensor = lambda: constant_op.constant(1) - fn_list = lambda: [constant_op.constant(2)] - fn_tuple = lambda: (constant_op.constant(3),) - self._testShape(fn_tensor, fn_list, shape) - self._testShape(fn_tensor, fn_tuple, shape) - self._testShape(fn_list, fn_tuple, shape) - self._testReturnValues(fn_tensor, fn_list, 1, 2) - self._testReturnValues(fn_tensor, fn_tuple, 1, 3) - self._testReturnValues(fn_list, fn_tuple, 2, 3) - - @test_util.run_v1_only("b/120553181") - def test_singleton_strict(self): - fn_tensor = lambda: constant_op.constant(1) - fn_list = lambda: [constant_op.constant(2)] - fn_tuple = lambda: (constant_op.constant(3),) - - with self.assertRaises(ValueError): - control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_list, - strict=True) - - with self.assertRaises(TypeError): - control_flow_ops.cond(constant_op.constant(True), fn_list, fn_tuple, - strict=True) - - with self.assertRaises(ValueError): - control_flow_ops.case([(constant_op.constant(True), fn_tensor)], fn_list, - strict=True) - - with self.assertRaises(TypeError): - 
control_flow_ops.case([(constant_op.constant(True), fn_list)], fn_tuple, - strict=True) - - @test_util.run_deprecated_v1 - def test_singleton_list(self): - shape = tensor_shape.TensorShape([]) - fn_true = lambda: [constant_op.constant(1)] - fn_false = lambda: [constant_op.constant(3)] - # Non-strict cond is only available in v1 - if not tf2.enabled(): - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, 1, 3) - self._testShape(fn_true, fn_false, [shape], strict=True) - self._testReturnValues(fn_true, fn_false, [1], [3], strict=True) - - @test_util.run_deprecated_v1 - def test_singleton_tuple(self): - shape = tensor_shape.TensorShape([]) - fn_true = lambda: (constant_op.constant(1),) - fn_false = lambda: (constant_op.constant(3),) - # Non-strict cond is only available in v1 - if not tf2.enabled(): - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, 1, 3) - self._testShape(fn_true, fn_false, (shape,), strict=True) - self._testReturnValues(fn_true, fn_false, (1,), (3,), - strict=True) - - @test_util.run_deprecated_v1 - def test_singleton_namedtuple(self): - shape = tensor_shape.TensorShape([]) - fn_true = lambda: SingletonTestTuple(constant_op.constant(1)) - fn_false = lambda: SingletonTestTuple(constant_op.constant(3)) - # Non-strict cond is only available in v1 - if not tf2.enabled(): - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, 1, 3) - self._testShape(fn_true, fn_false, SingletonTestTuple(shape), - strict=True) - self._testReturnValues(fn_true, fn_false, SingletonTestTuple(1), - SingletonTestTuple(3), strict=True) - - @test_util.run_deprecated_v1 - def test_tuple(self): - shape = (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) - fn_true = lambda: (constant_op.constant(1), 2) - fn_false = lambda: (constant_op.constant(3), 4) - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, (1, 2), (3, 4)) - - @test_util.run_deprecated_v1 - def test_namedtuple(self): - shape = TestTuple(tensor_shape.TensorShape([]), - tensor_shape.TensorShape([])) - fn_true = lambda: TestTuple(constant_op.constant(1), 2) - fn_false = lambda: TestTuple(constant_op.constant(3), 4) - self._testShape(fn_true, fn_false, shape) - self._testReturnValues(fn_true, fn_false, TestTuple(1, 2), TestTuple(3, 4)) - - @test_util.run_deprecated_v1 - def test_nested(self): - shape = [tensor_shape.TensorShape([]), - TestTuple(tensor_shape.TensorShape([]), - [tensor_shape.TensorShape([]), - tensor_shape.TensorShape([])]), - tensor_shape.TensorShape([5, 5]), - tensor_shape.TensorShape([])] - - def true_fn(): - return [constant_op.constant(1), - TestTuple(constant_op.constant(2), [3, 4]), - array_ops.zeros([5, 5]), 6] - - def false_fn(): - return [constant_op.constant(11), - TestTuple(constant_op.constant(12), [13, 14]), - array_ops.ones([5, 5]), 16] - - self._testShape(true_fn, false_fn, shape) - self._testReturnValues( - true_fn, false_fn, - [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6], - [11, TestTuple(12, [13, 14]), - np.ones([5, 5]), 16]) - - @test_util.run_deprecated_v1 - def test_cond_inside_while_loop(self): - - def body(i, matrix): - result_tuple, unused_matrix = control_flow_ops.cond( - constant_op.constant(True), - lambda: (TestTuple(matrix * 2, matrix * 4), matrix), - lambda: (TestTuple(matrix * 4, matrix * 2), matrix)) - return [i+1, result_tuple.a] - - iteration, matrix = control_flow_ops.while_loop( - lambda i, matrix: i < 10, - body, - 
loop_vars=[constant_op.constant(0), - array_ops.ones([2, 2])]) - - self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([])) - self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2])) - - -class CaseTest(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testCase_withDefault(self): - x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) - conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), - (math_ops.equal(x, 2), lambda: constant_op.constant(4))] - default = lambda: constant_op.constant(6) - output = control_flow_ops.case(conditions, default, exclusive=True) - with self.cached_session() as sess: - self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) - self.assertEqual(sess.run(output, feed_dict={x: 2}), 4) - self.assertEqual(sess.run(output, feed_dict={x: 3}), 6) - - @test_util.run_deprecated_v1 - def testCase_multiple_matches_exclusive(self): - x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) - conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), - (math_ops.equal(x, 2), lambda: constant_op.constant(4)), - (math_ops.equal(x, 2), lambda: constant_op.constant(6))] - default = lambda: constant_op.constant(8) - output = control_flow_ops.case(conditions, default, exclusive=True) - with self.cached_session() as sess: - self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) - self.assertEqual(sess.run(output, feed_dict={x: 3}), 8) - with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"): - sess.run(output, feed_dict={x: 2}) - - @test_util.run_deprecated_v1 - def testCase_multiple_matches_non_exclusive(self): - x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) - conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), - (math_ops.equal(x, 2), lambda: constant_op.constant(4)), - (math_ops.equal(x, 2), lambda: constant_op.constant(6))] - default = lambda: constant_op.constant(8) - output = control_flow_ops.case(conditions, default, exclusive=False) - with self.cached_session() as sess: - self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) - self.assertEqual(sess.run(output, feed_dict={x: 2}), 4) - self.assertEqual(sess.run(output, feed_dict={x: 3}), 8) - - @test_util.run_deprecated_v1 - def testCase_withoutDefault(self): - x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) - conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), - (math_ops.equal(x, 2), lambda: constant_op.constant(4)), - (math_ops.equal(x, 3), lambda: constant_op.constant(6))] - output = control_flow_ops.case(conditions, exclusive=True) - with self.cached_session() as sess: - self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) - self.assertEqual(sess.run(output, feed_dict={x: 2}), 4) - self.assertEqual(sess.run(output, feed_dict={x: 3}), 6) - with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"): - sess.run(output, feed_dict={x: 4}) - - @test_util.run_deprecated_v1 - def testCase_withoutDefault_oneCondition(self): - x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) - conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2))] - output = control_flow_ops.case(conditions, exclusive=True) - with self.cached_session() as sess: - self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) - with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"): - sess.run(output, feed_dict={x: 4}) - - @test_util.run_in_graph_and_eager_modes - def testCase_dict(self): - x = constant_op.constant(2) - conditions = { - math_ops.equal(x, 1): 
lambda: constant_op.constant(2), - math_ops.equal(x, 2): lambda: constant_op.constant(4) - } - output = control_flow_ops.case(conditions, exclusive=True) - self.assertEqual(4, self.evaluate(output)) - - -class WhileLoopTestCase(test_util.TensorFlowTestCase): - - @test_util.run_in_graph_and_eager_modes - def testWhileLoopWithSingleVariable(self): - i = constant_op.constant(0) - c = lambda i: math_ops.less(i, 10) - b = lambda i: math_ops.add(i, 1) - r = control_flow_ops.while_loop(c, b, [i]) - - self.assertEqual(self.evaluate(r), 10) - - @test_util.run_in_graph_and_eager_modes - def testEagerWhileLoopWithSingleVariable_bodyReturnsTuple(self): - i = constant_op.constant(0) - c = lambda i: math_ops.less(i, 10) - b = lambda i: (math_ops.add(i, 1),) - r = control_flow_ops.while_loop(c, b, [i]) - - # Expect a tuple since that is what the body returns. - self.assertEqual(self.evaluate(r), (10,)) - - @test_util.run_deprecated_v1 - def testWhileLoopSameReturnShape_False(self): - i = constant_op.constant(0) - c = lambda i, _: math_ops.less(i, 10) - - # Body returns a [tensor, []] - b = lambda i, _: [math_ops.add(i, 1), []] - - # Should only return the tensor. - r = control_flow_ops.while_loop(c, b, [i, []]) - self.assertEqual(self.evaluate(r), 10) - - def testWhileLoopSameReturnShape_True(self): - i = constant_op.constant(0) - c = lambda i, _: math_ops.less(i, 10) - - # Body returns a [tensor, []] - b = lambda i, _: [math_ops.add(i, 1), []] - - # Should only return the original structure. - r = control_flow_ops.while_loop(c, b, [i, []], return_same_structure=True) - self.assertEqual(self.evaluate(r), [10, []]) - - -class AssertTest(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testAssert(self): - i = constant_op.constant(0) - c = control_flow_ops.Assert(i < 10, [i, [10], [i + 1]]) - self.evaluate(c) - - i = constant_op.constant(10) - c = control_flow_ops.Assert(i < 10, [i, [10], [i + 1]]) - with self.assertRaises(errors.InvalidArgumentError): - self.evaluate(c) - - @test_util.run_in_graph_and_eager_modes - def testAssertInFunction(self): - - @def_function.function - def whiny(value): - control_flow_ops.Assert(value, ["Raised false"]) - return constant_op.constant(5) - - with self.assertRaises(errors.InvalidArgumentError): - self.evaluate(whiny(False)) - - self.assertAllEqual(whiny(True), 5) - -if __name__ == "__main__": - googletest.main() diff --git a/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs b/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs deleted file mode 100644 index c8928fa9..00000000 --- a/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs +++ /dev/null @@ -1,875 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using Newtonsoft.Json.Linq; -using NumSharp; -using System; -using System.Collections; -using System.Collections.Generic; -using Tensorflow; -using Tensorflow.UnitTest; -using Tensorflow.Util; -using static Tensorflow.Binding; - -namespace TensorFlowNET.UnitTest.nest_test -{ - /// - /// excerpt of tensorflow/python/framework/util/nest_test.py - /// - [TestClass] - public class NestTest : GraphModeTestBase - { - [TestInitialize] - public void TestInitialize() - { - tf.Graph().as_default(); - } - - //public class PointXY - //{ - // public double x; - // public double y; - //} - - // if attr: - // class BadAttr(object): - // """Class that has a non-iterable __attrs_attrs__.""" - // __attrs_attrs__ = None - - // @attr.s - // class SampleAttr(object): - // field1 = attr.ib() - // field2 = attr.ib() - - // 
@test_util.assert_no_new_pyobjects_executing_eagerly
-        // def testAttrsFlattenAndPack(self) :
-        //   if attr is None:
-        //     self.skipTest("attr module is unavailable.")
-
-        //   field_values = [1, 2]
-        //   sample_attr = NestTest.SampleAttr(* field_values)
-        //   self.assertFalse(nest._is_attrs(field_values))
-        //   self.assertTrue(nest._is_attrs(sample_attr))
-        //   flat = nest.flatten(sample_attr)
-        //   self.assertEqual(field_values, flat)
-        //   restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
-        //   self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
-        //   self.assertEqual(restructured_from_flat, sample_attr)
-
-        //# Check that flatten fails if attributes are not iterable
-        //   with self.assertRaisesRegexp(TypeError, "object is not iterable"):
-        //     flat = nest.flatten(NestTest.BadAttr())
-        [Ignore]
-        [TestMethod]
-        public void testFlattenAndPack()
-        {
-            object structure = new object[] { new object[] { 3, 4 }, 5, new object[] { 6, 7, new object[] { 9, 10 }, 8 } };
-            var flat = new List<object> { "a", "b", "c", "d", "e", "f", "g", "h" };
-
-            self.assertEqual(nest.flatten(structure), new[] { 3, 4, 5, 6, 7, 9, 10, 8 });
-            self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(),
-                JArray.FromObject(new object[] { new object[] { "a", "b" }, "c", new object[] { "d", "e", new object[] { "f", "g" }, "h" } }).ToString());
-            structure = new object[] { new Hashtable { ["x"] = 4, ["y"] = 2 }, new object[] { new object[] { new Hashtable { ["x"] = 1, ["y"] = 0 }, }, } };
-            flat = new List<object> { 4, 2, 1, 0 };
-            self.assertEqual(nest.flatten(structure), flat);
-            var restructured_from_flat = nest.pack_sequence_as(structure, flat) as object[];
-            //Console.WriteLine(JArray.FromObject(restructured_from_flat));
-            self.assertEqual(restructured_from_flat, structure);
-            self.assertEqual((restructured_from_flat[0] as Hashtable)["x"], 4);
-            self.assertEqual((restructured_from_flat[0] as Hashtable)["y"], 2);
-            self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["x"], 1);
-            self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["y"], 0);
-
-            self.assertEqual(new List<object> { 5 }, nest.flatten(5));
-            var flat1 = nest.flatten(np.array(new[] { 5 }));
-            self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat1);
-
-            self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" }));
-            self.assertEqual(np.array(new[] { 5 }),
-                nest.pack_sequence_as("scalar", new List<object> { np.array(new[] { 5 }) }));
-
-            Assert.ThrowsException<ValueError>(() => nest.pack_sequence_as("scalar", new List<object>() { 4, 5 }));
-
-            Assert.ThrowsException<ValueError>(() =>
-                nest.pack_sequence_as(new object[] { 5, 6, new object[] { 7, 8 } }, new List<object> { "a", "b", "c" }));
-        }
-
-        // @parameterized.parameters({"mapping_type": collections.OrderedDict
-        // },
-        // {"mapping_type": _CustomMapping
-        //})
-        // @test_util.assert_no_new_pyobjects_executing_eagerly
-        // def testFlattenDictOrder(self, mapping_type) :
-        //   """`flatten` orders dicts by key, including OrderedDicts."""
-        //   ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
-        //   plain = {"d": 3, "b": 1, "a": 0, "c": 2}
-        //   ordered_flat = nest.flatten(ordered)
-        //   plain_flat = nest.flatten(plain)
-        //   self.assertEqual([0, 1, 2, 3], ordered_flat)
-        //   self.assertEqual([0, 1, 2, 3], plain_flat)
-
-        // @parameterized.parameters({"mapping_type": collections.OrderedDict},
-        // {"mapping_type": _CustomMapping})
-        // def testPackDictOrder(self, mapping_type):
-        //   """Packing orders dicts by key, including
OrderedDicts.""" - // custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) - // plain = {"d": 0, "b": 0, "a": 0, "c": 0} - // seq = [0, 1, 2, 3] - //custom_reconstruction = nest.pack_sequence_as(custom, seq) - //plain_reconstruction = nest.pack_sequence_as(plain, seq) - // self.assertIsInstance(custom_reconstruction, mapping_type) - // self.assertIsInstance(plain_reconstruction, dict) - // self.assertEqual( - // mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), - // custom_reconstruction) - // self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) - - // Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name - - // @test_util.assert_no_new_pyobjects_executing_eagerly - // def testFlattenAndPack_withDicts(self) : - // # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. - // mess = [ - // "z", - // NestTest.Abc(3, 4), { - // "d": _CustomMapping({ - // 41: 4 - // }), - // "c": [ - // 1, - // collections.OrderedDict([ - // ("b", 3), - // ("a", 2), - // ]), - // ], - // "b": 5 - // }, 17 - // ] - - // flattened = nest.flatten(mess) - // self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) - - // structure_of_mess = [ - // 14, - // NestTest.Abc("a", True), - // { - // "d": _CustomMapping({ - // 41: 42 - // }), - // "c": [ - // 0, - // collections.OrderedDict([ - // ("b", 9), - // ("a", 8), - // ]), - // ], - // "b": 3 - // }, - // "hi everybody", - // ] - - // unflattened = nest.pack_sequence_as(structure_of_mess, flattened) - // self.assertEqual(unflattened, mess) - - // # Check also that the OrderedDict was created, with the correct key order. - //unflattened_ordered_dict = unflattened[2]["c"][1] - // self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) - // self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) - - // unflattened_custom_mapping = unflattened[2]["d"] - // self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) - // self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) - - [TestMethod] - public void testFlatten_numpyIsNotFlattened() - { - var structure = np.array(1, 2, 3); - var flattened = nest.flatten(structure); - self.assertEqual(len(flattened), 1); - } - - [TestMethod] - public void testFlatten_stringIsNotFlattened() - { - var structure = "lots of letters"; - var flattened = nest.flatten(structure); - self.assertEqual(len(flattened), 1); - var unflattened = nest.pack_sequence_as("goodbye", flattened); - self.assertEqual(structure, unflattened); - } - - // def testPackSequenceAs_notIterableError(self) : - // with self.assertRaisesRegexp(TypeError, - // "flat_sequence must be a sequence"): - // nest.pack_sequence_as("hi", "bye") - - [TestMethod] - public void testPackSequenceAs_wrongLengthsError() - { - Assert.ThrowsException(() => - { - // with self.assertRaisesRegexp( - // ValueError, - // "Structure had 2 elements, but flat_sequence had 3 elements."): - nest.pack_sequence_as(new object[] { "hello", "world" }, new object[] { "and", "goodbye", "again" }); - }); - } - - [Ignore] - [TestMethod] - public void testIsSequence() - { - self.assertFalse(nest.is_sequence("1234")); - self.assertTrue(nest.is_sequence(new object[] { 1, 3, new object[] { 4, 5 } })); - // TODO: ValueTuple - //self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))); - self.assertTrue(nest.is_sequence(new object[] { })); - self.assertTrue(nest.is_sequence(new Hashtable { ["a"] = 1, ["b"] = 2 })); - self.assertFalse(nest.is_sequence(new HashSet { 1, 2 })); - var ones = 
array_ops.ones(new int[] { 2, 3 }); - self.assertFalse(nest.is_sequence(ones)); - self.assertFalse(nest.is_sequence(gen_math_ops.tanh(ones))); - self.assertFalse(nest.is_sequence(np.ones(new int[] { 4, 5 }))); - } - - // @parameterized.parameters({"mapping_type": _CustomMapping}, - // {"mapping_type": dict}) - // def testFlattenDictItems(self, mapping_type): - // dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))}) - // flat = {4: "a", 5: "b", 6: "c", 8: "d"} - // self.assertEqual(nest.flatten_dict_items(dictionary), flat) - - // with self.assertRaises(TypeError): - // nest.flatten_dict_items(4) - - // bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))}) - // with self.assertRaisesRegexp(ValueError, "not unique"): - // nest.flatten_dict_items(bad_dictionary) - - // another_bad_dictionary = mapping_type({ - // (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) - // }) - // with self.assertRaisesRegexp( - // ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): - // nest.flatten_dict_items(another_bad_dictionary) - - //# pylint does not correctly recognize these as class names and - //# suggests to use variable style under_score naming. - //# pylint: disable=invalid-name - // Named0ab = collections.namedtuple("named_0", ("a", "b")) - // Named1ab = collections.namedtuple("named_1", ("a", "b")) - // SameNameab = collections.namedtuple("same_name", ("a", "b")) - // SameNameab2 = collections.namedtuple("same_name", ("a", "b")) - // SameNamexy = collections.namedtuple("same_name", ("x", "y")) - // SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) - // SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) - // NotSameName = collections.namedtuple("not_same_name", ("a", "b")) - // # pylint: enable=invalid-name - - // class SameNamedType1(SameNameab): - // pass - - // @test_util.assert_no_new_pyobjects_executing_eagerly - // def testAssertSameStructure(self): - // structure1 = (((1, 2), 3), 4, (5, 6)) - // structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) - // structure_different_num_elements = ("spam", "eggs") - // structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) - // nest.assert_same_structure(structure1, structure2) - // nest.assert_same_structure("abc", 1.0) - // nest.assert_same_structure("abc", np.array([0, 1])) - // nest.assert_same_structure("abc", constant_op.constant([0, 1])) - - // with self.assertRaisesRegexp( - // ValueError, - // ("The two structures don't have the same nested structure\\.\n\n" - // "First structure:.*?\n\n" - // "Second structure:.*\n\n" - // "More specifically: Substructure " - // r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' - // 'substructure "type=str str=spam" is not\n' - // "Entire first structure:\n" - // r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" - // "Entire second structure:\n" - // r"\(\., \.\)")): - // nest.assert_same_structure(structure1, structure_different_num_elements) - - // with self.assertRaisesRegexp( - // ValueError, - // ("The two structures don't have the same nested structure\\.\n\n" - // "First structure:.*?\n\n" - // "Second structure:.*\n\n" - // r'More specifically: Substructure "type=list str=\[0, 1\]" ' - // r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' - // "is not")): - // nest.assert_same_structure([0, 1], np.array([0, 1])) - - // with self.assertRaisesRegexp( - // ValueError, - // ("The two structures don't have the same nested structure\\.\n\n" - // "First structure:.*?\n\n" - // "Second structure:.*\n\n" 
- // r'More specifically: Substructure "type=list str=\[0, 1\]" ' - // 'is a sequence, while substructure "type=int str=0" ' - // "is not")): - // nest.assert_same_structure(0, [0, 1]) - - // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) - - // with self.assertRaisesRegexp( - // ValueError, - // ("don't have the same nested structure\\.\n\n" - // "First structure: .*?\n\nSecond structure: ")): - // nest.assert_same_structure(structure1, structure_different_nesting) - - // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), - // NestTest.Named0ab("a", "b")) - - // nest.assert_same_structure(NestTest.Named0ab(3, 4), - // NestTest.Named0ab("a", "b")) - - // self.assertRaises(TypeError, nest.assert_same_structure, - // NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) - - // with self.assertRaisesRegexp( - // ValueError, - // ("don't have the same nested structure\\.\n\n" - // "First structure: .*?\n\nSecond structure: ")): - // nest.assert_same_structure(NestTest.Named0ab(3, 4), - // NestTest.Named0ab([3], 4)) - - // with self.assertRaisesRegexp( - // ValueError, - // ("don't have the same nested structure\\.\n\n" - // "First structure: .*?\n\nSecond structure: ")): - // nest.assert_same_structure([[3], 4], [3, [4]]) - - // structure1_list = [[[1, 2], 3], 4, [5, 6]] - // with self.assertRaisesRegexp(TypeError, - // "don't have the same sequence type"): - // nest.assert_same_structure(structure1, structure1_list) - // nest.assert_same_structure(structure1, structure2, check_types= False) - // nest.assert_same_structure(structure1, structure1_list, check_types=False) - - // with self.assertRaisesRegexp(ValueError, - // "don't have the same set of keys"): - // nest.assert_same_structure({"a": 1}, {"b": 1}) - - // nest.assert_same_structure(NestTest.SameNameab(0, 1), - // NestTest.SameNameab2(2, 3)) - - // # This assertion is expected to pass: two namedtuples with the same - // # name and field names are considered to be identical. 
- // nest.assert_same_structure( - // NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), - // NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) - - // expected_message = "The two structures don't have the same.*" - // with self.assertRaisesRegexp(ValueError, expected_message): - // nest.assert_same_structure( - // NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), - // NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) - - // self.assertRaises(TypeError, nest.assert_same_structure, - // NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) - - // self.assertRaises(TypeError, nest.assert_same_structure, - // NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) - - // self.assertRaises(TypeError, nest.assert_same_structure, - // NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) - - // EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name - - // def testHeterogeneousComparison(self): - // nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3)) - // nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) - [Ignore] - [TestMethod] - public void testMapStructure() - { - var structure1 = new object[] { new object[] { new object[] { 1, 2 }, 3 }, 4, new object[] { 5, 6 } }; - var structure2 = new object[] { new object[] { new object[] { 7, 8 }, 9 }, 10, new object[] { 11, 12 } }; - var structure1_plus1 = nest.map_structure(x => (int)x + 1, structure1); - var structure1_strings = nest.map_structure(x => $"{x}", structure1); - var s = JArray.FromObject(structure1_plus1).ToString(); - Console.WriteLine(s); - // nest.assert_same_structure(structure1, structure1_plus1) - self.assertAllEqual(nest.flatten(structure1_plus1), new object[] { 2, 3, 4, 5, 6, 7 }); - self.assertAllEqual(nest.flatten(structure1_strings), new object[] { "1", "2", "3", "4", "5", "6" }); - var structure1_plus_structure2 = nest.map_structure(x => (int)(x[0]) + (int)(x[1]), structure1, structure2); - self.assertEqual( - new object[] { new object[] { new object[] { 1 + 7, 2 + 8 }, 3 + 9 }, 4 + 10, new object[] { 5 + 11, 6 + 12 } }, - structure1_plus_structure2); - - // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) - - // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) - - // # Empty structures - // self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) - // self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) - // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) - // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, - // NestTest.EmptyNT())) - - // # This is checking actual equality of types, empty list != empty tuple - // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) - - // with self.assertRaisesRegexp(TypeError, "callable"): - // nest.map_structure("bad", structure1_plus1) - - // with self.assertRaisesRegexp(ValueError, "at least one structure"): - // nest.map_structure(lambda x: x) - - // with self.assertRaisesRegexp(ValueError, "same number of elements"): - // nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) - - // with self.assertRaisesRegexp(ValueError, "same nested structure"): - // nest.map_structure(lambda x, y: None, 3, (3,)) - - // with self.assertRaisesRegexp(TypeError, "same sequence type"): - // nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) - - // with self.assertRaisesRegexp(ValueError, "same nested structure"): - // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) - - // structure1_list = [[[1, 2], 3], 4, [5, 6]] - // with 
self.assertRaisesRegexp(TypeError, "same sequence type"): - // nest.map_structure(lambda x, y: None, structure1, structure1_list) - - // nest.map_structure(lambda x, y: None, structure1, structure1_list, - // check_types=False) - - // with self.assertRaisesRegexp(ValueError, "same nested structure"): - // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), - // check_types=False) - - // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): - // nest.map_structure(lambda x: None, structure1, foo="a") - - // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): - // nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") - - // ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name - } - - // @test_util.assert_no_new_pyobjects_executing_eagerly - // def testMapStructureWithStrings(self) : - // inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) - // inp_b = NestTest.ABTuple(a=2, b=(1, 3)) - // out = nest.map_structure(lambda string, repeats: string* repeats, - // inp_a, - // inp_b) - // self.assertEqual("foofoo", out.a) - // self.assertEqual("bar", out.b[0]) - // self.assertEqual("bazbazbaz", out.b[1]) - - // nt = NestTest.ABTuple(a=("something", "something_else"), - // b="yet another thing") - // rev_nt = nest.map_structure(lambda x: x[::- 1], nt) - // # Check the output is the correct structure, and all strings are reversed. - // nest.assert_same_structure(nt, rev_nt) - // self.assertEqual(nt.a[0][::- 1], rev_nt.a[0]) - // self.assertEqual(nt.a[1][::- 1], rev_nt.a[1]) - // self.assertEqual(nt.b[::- 1], rev_nt.b) - - // @test_util.run_deprecated_v1 - // def testMapStructureOverPlaceholders(self) : - // inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), - // array_ops.placeholder(dtypes.float32, shape=[3, 7])) - // inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), - // array_ops.placeholder(dtypes.float32, shape=[3, 7])) - - // output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) - - // nest.assert_same_structure(output, inp_a) - // self.assertShapeEqual(np.zeros((3, 4)), output[0]) - // self.assertShapeEqual(np.zeros((3, 7)), output[1]) - - // feed_dict = { - // inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), - // inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) - // } - - // with self.cached_session() as sess: - // output_np = sess.run(output, feed_dict=feed_dict) - // self.assertAllClose(output_np[0], - // feed_dict[inp_a][0] + feed_dict[inp_b][0]) - // self.assertAllClose(output_np[1], - // feed_dict[inp_a][1] + feed_dict[inp_b][1]) - - // def testAssertShallowStructure(self): - // inp_ab = ["a", "b"] - //inp_abc = ["a", "b", "c"] - //expected_message = ( - // "The two structures don't have the same sequence length. Input " - // "structure has length 2, while shallow structure has length 3.") - // with self.assertRaisesRegexp(ValueError, expected_message): - // nest.assert_shallow_structure(inp_abc, inp_ab) - - // inp_ab1 = [(1, 1), (2, 2)] - // inp_ab2 = [[1, 1], [2, 2]] - // expected_message = ( - // "The two structures don't have the same sequence type. 
Input structure " - // "has type <(type|class) 'tuple'>, while shallow structure has type " - // "<(type|class) 'list'>.") - // with self.assertRaisesRegexp(TypeError, expected_message): - // nest.assert_shallow_structure(inp_ab2, inp_ab1) - // nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False) - - // inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} - // inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} - // expected_message = ( - // r"The two structures don't have the same keys. Input " - // r"structure has keys \['c'\], while shallow structure has " - // r"keys \['d'\].") - - // with self.assertRaisesRegexp(ValueError, expected_message): - // nest.assert_shallow_structure(inp_ab2, inp_ab1) - - // inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) - // inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) - // nest.assert_shallow_structure(inp_ab, inp_ba) - - // # This assertion is expected to pass: two namedtuples with the same - //# name and field names are considered to be identical. - //inp_shallow = NestTest.SameNameab(1, 2) - // inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) - // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) - // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) - - // def testFlattenUpTo(self): - // # Shallow tree ends at scalar. - // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] - // shallow_tree = [[True, True], [False, True]] - // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - // self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) - // self.assertEqual(flattened_shallow_tree, [True, True, False, True]) - - //# Shallow tree ends at string. - // input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] - // shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] - // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, - // input_tree) - // input_tree_flattened = nest.flatten(input_tree) - // self.assertEqual(input_tree_flattened_as_shallow_tree, - // [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) - // self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) - - // # Make sure dicts are correctly flattened, yielding values, not keys. - //input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} - // shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} - // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, - // input_tree) - // self.assertEqual(input_tree_flattened_as_shallow_tree, - // [1, { "c": 2}, 3, (4, 5)]) - - // # Namedtuples. - // ab_tuple = NestTest.ABTuple - // input_tree = ab_tuple(a =[0, 1], b = 2) - // shallow_tree = ab_tuple(a= 0, b= 1) - // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, - // input_tree) - // self.assertEqual(input_tree_flattened_as_shallow_tree, - // [[0, 1], 2]) - - // # Nested dicts, OrderedDicts and namedtuples. 
- // input_tree = collections.OrderedDict( - // [("a", ab_tuple(a =[0, {"b": 1}], b=2)), - // ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) - // shallow_tree = input_tree - // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, - // input_tree) - // self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) - // shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) - // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, - // input_tree) - // self.assertEqual(input_tree_flattened_as_shallow_tree, - // [ab_tuple(a =[0, { "b": 1}], b=2), - // 3, - // collections.OrderedDict([("f", 4)])]) - // shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) - // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, - // input_tree) - // self.assertEqual(input_tree_flattened_as_shallow_tree, - // [ab_tuple(a =[0, {"b": 1}], b=2), - // {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) - - // ## Shallow non-list edge-case. - // # Using iterable elements. - // input_tree = ["input_tree"] - //shallow_tree = "shallow_tree" - // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - // self.assertEqual(flattened_input_tree, [input_tree]) - // self.assertEqual(flattened_shallow_tree, [shallow_tree]) - - // input_tree = ["input_tree_0", "input_tree_1"] - //shallow_tree = "shallow_tree" - // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - // self.assertEqual(flattened_input_tree, [input_tree]) - // self.assertEqual(flattened_shallow_tree, [shallow_tree]) - - // # Using non-iterable elements. - //input_tree = [0] - //shallow_tree = 9 - // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - // self.assertEqual(flattened_input_tree, [input_tree]) - // self.assertEqual(flattened_shallow_tree, [shallow_tree]) - - // input_tree = [0, 1] - //shallow_tree = 9 - // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - // self.assertEqual(flattened_input_tree, [input_tree]) - // self.assertEqual(flattened_shallow_tree, [shallow_tree]) - - // ## Both non-list edge-case. - //# Using iterable elements. - //input_tree = "input_tree" - // shallow_tree = "shallow_tree" - // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - // self.assertEqual(flattened_input_tree, [input_tree]) - // self.assertEqual(flattened_shallow_tree, [shallow_tree]) - - // # Using non-iterable elements. - //input_tree = 0 - // shallow_tree = 0 - // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - // self.assertEqual(flattened_input_tree, [input_tree]) - // self.assertEqual(flattened_shallow_tree, [shallow_tree]) - - // ## Input non-list edge-case. - //# Using iterable elements. - //input_tree = "input_tree" - // shallow_tree = ["shallow_tree"] - //expected_message = ("If shallow structure is a sequence, input must also " - // "be a sequence. 
Input has type: <(type|class) 'str'>.") - // with self.assertRaisesRegexp(TypeError, expected_message): - // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - // self.assertEqual(flattened_shallow_tree, shallow_tree) - - // input_tree = "input_tree" - // shallow_tree = ["shallow_tree_9", "shallow_tree_8"] - //with self.assertRaisesRegexp(TypeError, expected_message): - // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - // self.assertEqual(flattened_shallow_tree, shallow_tree) - - //# Using non-iterable elements. - // input_tree = 0 - // shallow_tree = [9] - //expected_message = ("If shallow structure is a sequence, input must also " - // "be a sequence. Input has type: <(type|class) 'int'>.") - // with self.assertRaisesRegexp(TypeError, expected_message): - // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - // self.assertEqual(flattened_shallow_tree, shallow_tree) - - // input_tree = 0 - // shallow_tree = [9, 8] - //with self.assertRaisesRegexp(TypeError, expected_message): - // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - // self.assertEqual(flattened_shallow_tree, shallow_tree) - - // def testMapStructureUpTo(self) : - // # Named tuples. - // ab_tuple = collections.namedtuple("ab_tuple", "a, b") - // op_tuple = collections.namedtuple("op_tuple", "add, mul") - // inp_val = ab_tuple(a= 2, b= 3) - // inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3)) - // out = nest.map_structure_up_to( - // inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) - // self.assertEqual(out.a, 6) - // self.assertEqual(out.b, 15) - - // # Lists. - // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] - // name_list = ["evens", ["odds", "primes"]] - // out = nest.map_structure_up_to( - // name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), - // name_list, data_list) - // self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) - - // # Dicts. - // inp_val = dict(a= 2, b= 3) - // inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) - // out = nest.map_structure_up_to( - // inp_val, - // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) - // self.assertEqual(out["a"], 6) - // self.assertEqual(out["b"], 15) - - // # Non-equal dicts. - // inp_val = dict(a= 2, b= 3) - // inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) - // with self.assertRaisesRegexp(ValueError, "same keys"): - // nest.map_structure_up_to( - // inp_val, - // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) - - // # Dict+custom mapping. - // inp_val = dict(a= 2, b= 3) - // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) - // out = nest.map_structure_up_to( - // inp_val, - // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) - // self.assertEqual(out["a"], 6) - // self.assertEqual(out["b"], 15) - - // # Non-equal dict/mapping. 
- // inp_val = dict(a= 2, b= 3) - // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) - // with self.assertRaisesRegexp(ValueError, "same keys"): - // nest.map_structure_up_to( - // inp_val, - // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) - - // def testGetTraverseShallowStructure(self): - // scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] - // scalar_traverse_r = nest.get_traverse_shallow_structure( - // lambda s: not isinstance(s, tuple), - // scalar_traverse_input) - // self.assertEqual(scalar_traverse_r, - // [True, True, False, [True, True], {"a": False}, []]) - // nest.assert_shallow_structure(scalar_traverse_r, - // scalar_traverse_input) - - // structure_traverse_input = [(1, [2]), ([1], 2)] - // structure_traverse_r = nest.get_traverse_shallow_structure( - // lambda s: (True, False) if isinstance(s, tuple) else True, - // structure_traverse_input) - // self.assertEqual(structure_traverse_r, - // [(True, False), ([True], False)]) - // nest.assert_shallow_structure(structure_traverse_r, - // structure_traverse_input) - - // with self.assertRaisesRegexp(TypeError, "returned structure"): - // nest.get_traverse_shallow_structure(lambda _: [True], 0) - - // with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): - // nest.get_traverse_shallow_structure(lambda _: 1, [1]) - - // with self.assertRaisesRegexp( - // TypeError, "didn't return a depth=1 structure of bools"): - // nest.get_traverse_shallow_structure(lambda _: [1], [1]) - - // def testYieldFlatStringPaths(self): - // for inputs_expected in ({"inputs": [], "expected": []}, - // {"inputs": 3, "expected": [()]}, - // {"inputs": [3], "expected": [(0,)]}, - // {"inputs": {"a": 3}, "expected": [("a",)]}, - // {"inputs": {"a": {"b": 4}}, - // "expected": [("a", "b")]}, - // {"inputs": [{"a": 2}], "expected": [(0, "a")]}, - // {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, - // {"inputs": [{"a": [(23, 42)]}], - // "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, - // {"inputs": [{"a": ([23], 42)}], - // "expected": [(0, "a", 0, 0), (0, "a", 1)]}, - // {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, - // "expected": [("a", "a"), ("c", 0, 0, 0)]}, - // {"inputs": {"0": [{"1": 23}]}, - // "expected": [("0", 0, "1")]}): - // inputs = inputs_expected["inputs"] - // expected = inputs_expected["expected"] - // self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) - - // def testFlattenWithStringPaths(self): - // for inputs_expected in ( - // {"inputs": [], "expected": []}, - // {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, - // {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): - // inputs = inputs_expected["inputs"] - // expected = inputs_expected["expected"] - // self.assertEqual( - // nest.flatten_with_joined_string_paths(inputs, separator="/"), - // expected) - - // # Need a separate test for namedtuple as we can't declare tuple definitions - // # in the @parameterized arguments. 
- // def testFlattenNamedTuple(self): - // # pylint: disable=invalid-name - // Foo = collections.namedtuple("Foo", ["a", "b"]) - // Bar = collections.namedtuple("Bar", ["c", "d"]) - // # pylint: enable=invalid-name - // test_cases = [ - // (Foo(a = 3, b = Bar(c = 23, d = 42)), - // [("a", 3), ("b/c", 23), ("b/d", 42)]), - // (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")), - // [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), - // (Bar(c = 42, d = 43), - // [("c", 42), ("d", 43)]), - // (Bar(c =[42], d = 43), - // [("c/0", 42), ("d", 43)]), - // ] - // for inputs, expected in test_cases: - // self.assertEqual( - // list(nest.flatten_with_joined_string_paths(inputs)), expected) - - // @parameterized.named_parameters( - // ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), - // ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, - // {"a": ("a", 4), "b": ("b", 6)}), - // ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), - // ("nested", - // {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, - // {"a": [("a/0", 10), ("a/1", 12)], - // "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) - // def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): - // def format_sum(path, * values): - // return (path, sum(values)) - // result = nest.map_structure_with_paths(format_sum, s1, s2, - // check_types=check_types) - // self.assertEqual(expected, result) - - // @parameterized.named_parameters( - // ("tuples", (1, 2), (3, 4, 5), ValueError), - // ("dicts", {"a": 1}, {"b": 2}, ValueError), - // ("mixed", (1, 2), [3, 4], TypeError), - // ("nested", - // {"a": [2, 3], "b": [1, 3]}, - // {"b": [5, 6, 7], "a": [8, 9]}, - // ValueError - // )) - // def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): - // with self.assertRaises(error_type): - // nest.map_structure_with_paths(lambda path, * s: 0, s1, s2) - - - //class NestBenchmark(test.Benchmark): - - // def run_and_report(self, s1, s2, name): - // burn_iter, test_iter = 100, 30000 - - // for _ in xrange(burn_iter) : - // nest.assert_same_structure(s1, s2) - - // t0 = time.time() - // for _ in xrange(test_iter) : - // nest.assert_same_structure(s1, s2) - // t1 = time.time() - - // self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, - // name=name) - - // def benchmark_assert_structure(self): - // s1 = (((1, 2), 3), 4, (5, 6)) - // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) - // self.run_and_report(s1, s2, "assert_same_structure_6_elem") - - // s1 = (((1, 2), 3), 4, (5, 6)) * 10 - // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 - // self.run_and_report(s1, s2, "assert_same_structure_60_elem") - - - //if __name__ == "__main__": - // test.main() - } -} diff --git a/test/TensorFlowNET.UnitTest/nest_test/nest_test.py b/test/TensorFlowNET.UnitTest/nest_test/nest_test.py deleted file mode 100644 index d0d0c5f7..00000000 --- a/test/TensorFlowNET.UnitTest/nest_test/nest_test.py +++ /dev/null @@ -1,883 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for utilities working with arbitrarily nested structures.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import time - -from absl.testing import parameterized -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import test -from tensorflow.python.util import nest - -try: - import attr # pylint:disable=g-import-not-at-top -except ImportError: - attr = None - - -class _CustomMapping(collections.Mapping): - - def __init__(self, *args, **kwargs): - self._wrapped = dict(*args, **kwargs) - - def __getitem__(self, key): - return self._wrapped[key] - - def __iter__(self): - return iter(self._wrapped) - - def __len__(self): - return len(self._wrapped) - - -class NestTest(parameterized.TestCase, test.TestCase): - - PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name - - if attr: - class BadAttr(object): - """Class that has a non-iterable __attrs_attrs__.""" - __attrs_attrs__ = None - - @attr.s - class SampleAttr(object): - field1 = attr.ib() - field2 = attr.ib() - - @test_util.assert_no_new_pyobjects_executing_eagerly - def testAttrsFlattenAndPack(self): - if attr is None: - self.skipTest("attr module is unavailable.") - - field_values = [1, 2] - sample_attr = NestTest.SampleAttr(*field_values) - self.assertFalse(nest._is_attrs(field_values)) - self.assertTrue(nest._is_attrs(sample_attr)) - flat = nest.flatten(sample_attr) - self.assertEqual(field_values, flat) - restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) - self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) - self.assertEqual(restructured_from_flat, sample_attr) - - # Check that flatten fails if attributes are not iterable - with self.assertRaisesRegexp(TypeError, "object is not iterable"): - flat = nest.flatten(NestTest.BadAttr()) - - @test_util.assert_no_new_pyobjects_executing_eagerly - def testFlattenAndPack(self): - structure = ((3, 4), 5, (6, 7, (9, 10), 8)) - flat = ["a", "b", "c", "d", "e", "f", "g", "h"] - self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) - self.assertEqual( - nest.pack_sequence_as(structure, flat), (("a", "b"), "c", - ("d", "e", ("f", "g"), "h"))) - structure = (NestTest.PointXY(x=4, y=2), - ((NestTest.PointXY(x=1, y=0),),)) - flat = [4, 2, 1, 0] - self.assertEqual(nest.flatten(structure), flat) - restructured_from_flat = nest.pack_sequence_as(structure, flat) - self.assertEqual(restructured_from_flat, structure) - self.assertEqual(restructured_from_flat[0].x, 4) - self.assertEqual(restructured_from_flat[0].y, 2) - self.assertEqual(restructured_from_flat[1][0][0].x, 1) - self.assertEqual(restructured_from_flat[1][0][0].y, 0) - - self.assertEqual([5], nest.flatten(5)) - self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) - - self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) - self.assertEqual( - np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) - - with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): - 
nest.pack_sequence_as("scalar", [4, 5]) - - with self.assertRaisesRegexp(TypeError, "flat_sequence"): - nest.pack_sequence_as([4, 5], "bad_sequence") - - with self.assertRaises(ValueError): - nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) - - @parameterized.parameters({"mapping_type": collections.OrderedDict}, - {"mapping_type": _CustomMapping}) - @test_util.assert_no_new_pyobjects_executing_eagerly - def testFlattenDictOrder(self, mapping_type): - """`flatten` orders dicts by key, including OrderedDicts.""" - ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) - plain = {"d": 3, "b": 1, "a": 0, "c": 2} - ordered_flat = nest.flatten(ordered) - plain_flat = nest.flatten(plain) - self.assertEqual([0, 1, 2, 3], ordered_flat) - self.assertEqual([0, 1, 2, 3], plain_flat) - - @parameterized.parameters({"mapping_type": collections.OrderedDict}, - {"mapping_type": _CustomMapping}) - def testPackDictOrder(self, mapping_type): - """Packing orders dicts by key, including OrderedDicts.""" - custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) - plain = {"d": 0, "b": 0, "a": 0, "c": 0} - seq = [0, 1, 2, 3] - custom_reconstruction = nest.pack_sequence_as(custom, seq) - plain_reconstruction = nest.pack_sequence_as(plain, seq) - self.assertIsInstance(custom_reconstruction, mapping_type) - self.assertIsInstance(plain_reconstruction, dict) - self.assertEqual( - mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), - custom_reconstruction) - self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) - - Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name - - @test_util.assert_no_new_pyobjects_executing_eagerly - def testFlattenAndPack_withDicts(self): - # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. - mess = [ - "z", - NestTest.Abc(3, 4), { - "d": _CustomMapping({ - 41: 4 - }), - "c": [ - 1, - collections.OrderedDict([ - ("b", 3), - ("a", 2), - ]), - ], - "b": 5 - }, 17 - ] - - flattened = nest.flatten(mess) - self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) - - structure_of_mess = [ - 14, - NestTest.Abc("a", True), - { - "d": _CustomMapping({ - 41: 42 - }), - "c": [ - 0, - collections.OrderedDict([ - ("b", 9), - ("a", 8), - ]), - ], - "b": 3 - }, - "hi everybody", - ] - - unflattened = nest.pack_sequence_as(structure_of_mess, flattened) - self.assertEqual(unflattened, mess) - - # Check also that the OrderedDict was created, with the correct key order. 
- unflattened_ordered_dict = unflattened[2]["c"][1] - self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) - self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) - - unflattened_custom_mapping = unflattened[2]["d"] - self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) - self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) - - def testFlatten_numpyIsNotFlattened(self): - structure = np.array([1, 2, 3]) - flattened = nest.flatten(structure) - self.assertEqual(len(flattened), 1) - - def testFlatten_stringIsNotFlattened(self): - structure = "lots of letters" - flattened = nest.flatten(structure) - self.assertEqual(len(flattened), 1) - unflattened = nest.pack_sequence_as("goodbye", flattened) - self.assertEqual(structure, unflattened) - - def testPackSequenceAs_notIterableError(self): - with self.assertRaisesRegexp(TypeError, - "flat_sequence must be a sequence"): - nest.pack_sequence_as("hi", "bye") - - def testPackSequenceAs_wrongLengthsError(self): - with self.assertRaisesRegexp( - ValueError, - "Structure had 2 elements, but flat_sequence had 3 elements."): - nest.pack_sequence_as(["hello", "world"], - ["and", "goodbye", "again"]) - - @test_util.assert_no_new_pyobjects_executing_eagerly - def testIsSequence(self): - self.assertFalse(nest.is_sequence("1234")) - self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) - self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) - self.assertTrue(nest.is_sequence([])) - self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) - self.assertFalse(nest.is_sequence(set([1, 2]))) - ones = array_ops.ones([2, 3]) - self.assertFalse(nest.is_sequence(ones)) - self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) - self.assertFalse(nest.is_sequence(np.ones((4, 5)))) - - @parameterized.parameters({"mapping_type": _CustomMapping}, - {"mapping_type": dict}) - def testFlattenDictItems(self, mapping_type): - dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))}) - flat = {4: "a", 5: "b", 6: "c", 8: "d"} - self.assertEqual(nest.flatten_dict_items(dictionary), flat) - - with self.assertRaises(TypeError): - nest.flatten_dict_items(4) - - bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))}) - with self.assertRaisesRegexp(ValueError, "not unique"): - nest.flatten_dict_items(bad_dictionary) - - another_bad_dictionary = mapping_type({ - (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) - }) - with self.assertRaisesRegexp( - ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): - nest.flatten_dict_items(another_bad_dictionary) - - # pylint does not correctly recognize these as class names and - # suggests to use variable style under_score naming. 
- # pylint: disable=invalid-name - Named0ab = collections.namedtuple("named_0", ("a", "b")) - Named1ab = collections.namedtuple("named_1", ("a", "b")) - SameNameab = collections.namedtuple("same_name", ("a", "b")) - SameNameab2 = collections.namedtuple("same_name", ("a", "b")) - SameNamexy = collections.namedtuple("same_name", ("x", "y")) - SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) - SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) - NotSameName = collections.namedtuple("not_same_name", ("a", "b")) - # pylint: enable=invalid-name - - class SameNamedType1(SameNameab): - pass - - @test_util.assert_no_new_pyobjects_executing_eagerly - def testAssertSameStructure(self): - structure1 = (((1, 2), 3), 4, (5, 6)) - structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) - structure_different_num_elements = ("spam", "eggs") - structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) - nest.assert_same_structure(structure1, structure2) - nest.assert_same_structure("abc", 1.0) - nest.assert_same_structure("abc", np.array([0, 1])) - nest.assert_same_structure("abc", constant_op.constant([0, 1])) - - with self.assertRaisesRegexp( - ValueError, - ("The two structures don't have the same nested structure\\.\n\n" - "First structure:.*?\n\n" - "Second structure:.*\n\n" - "More specifically: Substructure " - r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' - 'substructure "type=str str=spam" is not\n' - "Entire first structure:\n" - r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" - "Entire second structure:\n" - r"\(\., \.\)")): - nest.assert_same_structure(structure1, structure_different_num_elements) - - with self.assertRaisesRegexp( - ValueError, - ("The two structures don't have the same nested structure\\.\n\n" - "First structure:.*?\n\n" - "Second structure:.*\n\n" - r'More specifically: Substructure "type=list str=\[0, 1\]" ' - r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' - "is not")): - nest.assert_same_structure([0, 1], np.array([0, 1])) - - with self.assertRaisesRegexp( - ValueError, - ("The two structures don't have the same nested structure\\.\n\n" - "First structure:.*?\n\n" - "Second structure:.*\n\n" - r'More specifically: Substructure "type=list str=\[0, 1\]" ' - 'is a sequence, while substructure "type=int str=0" ' - "is not")): - nest.assert_same_structure(0, [0, 1]) - - self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) - - with self.assertRaisesRegexp( - ValueError, - ("don't have the same nested structure\\.\n\n" - "First structure: .*?\n\nSecond structure: ")): - nest.assert_same_structure(structure1, structure_different_nesting) - - self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), - NestTest.Named0ab("a", "b")) - - nest.assert_same_structure(NestTest.Named0ab(3, 4), - NestTest.Named0ab("a", "b")) - - self.assertRaises(TypeError, nest.assert_same_structure, - NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) - - with self.assertRaisesRegexp( - ValueError, - ("don't have the same nested structure\\.\n\n" - "First structure: .*?\n\nSecond structure: ")): - nest.assert_same_structure(NestTest.Named0ab(3, 4), - NestTest.Named0ab([3], 4)) - - with self.assertRaisesRegexp( - ValueError, - ("don't have the same nested structure\\.\n\n" - "First structure: .*?\n\nSecond structure: ")): - nest.assert_same_structure([[3], 4], [3, [4]]) - - structure1_list = [[[1, 2], 3], 4, [5, 6]] - with self.assertRaisesRegexp(TypeError, - "don't have the same sequence type"): - 
nest.assert_same_structure(structure1, structure1_list) - nest.assert_same_structure(structure1, structure2, check_types=False) - nest.assert_same_structure(structure1, structure1_list, check_types=False) - - with self.assertRaisesRegexp(ValueError, - "don't have the same set of keys"): - nest.assert_same_structure({"a": 1}, {"b": 1}) - - nest.assert_same_structure(NestTest.SameNameab(0, 1), - NestTest.SameNameab2(2, 3)) - - # This assertion is expected to pass: two namedtuples with the same - # name and field names are considered to be identical. - nest.assert_same_structure( - NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), - NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) - - expected_message = "The two structures don't have the same.*" - with self.assertRaisesRegexp(ValueError, expected_message): - nest.assert_same_structure( - NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), - NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) - - self.assertRaises(TypeError, nest.assert_same_structure, - NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) - - self.assertRaises(TypeError, nest.assert_same_structure, - NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) - - self.assertRaises(TypeError, nest.assert_same_structure, - NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) - - EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name - - def testHeterogeneousComparison(self): - nest.assert_same_structure({"a": 4}, _CustomMapping(a=3)) - nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) - - @test_util.assert_no_new_pyobjects_executing_eagerly - def testMapStructure(self): - structure1 = (((1, 2), 3), 4, (5, 6)) - structure2 = (((7, 8), 9), 10, (11, 12)) - structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) - nest.assert_same_structure(structure1, structure1_plus1) - self.assertAllEqual( - [2, 3, 4, 5, 6, 7], - nest.flatten(structure1_plus1)) - structure1_plus_structure2 = nest.map_structure( - lambda x, y: x + y, structure1, structure2) - self.assertEqual( - (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), - structure1_plus_structure2) - - self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) - - self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) - - # Empty structures - self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) - self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) - self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) - self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, - NestTest.EmptyNT())) - - # This is checking actual equality of types, empty list != empty tuple - self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) - - with self.assertRaisesRegexp(TypeError, "callable"): - nest.map_structure("bad", structure1_plus1) - - with self.assertRaisesRegexp(ValueError, "at least one structure"): - nest.map_structure(lambda x: x) - - with self.assertRaisesRegexp(ValueError, "same number of elements"): - nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) - - with self.assertRaisesRegexp(ValueError, "same nested structure"): - nest.map_structure(lambda x, y: None, 3, (3,)) - - with self.assertRaisesRegexp(TypeError, "same sequence type"): - nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) - - with self.assertRaisesRegexp(ValueError, "same nested structure"): - nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) - - structure1_list = [[[1, 2], 3], 4, [5, 6]] - with self.assertRaisesRegexp(TypeError, 
"same sequence type"): - nest.map_structure(lambda x, y: None, structure1, structure1_list) - - nest.map_structure(lambda x, y: None, structure1, structure1_list, - check_types=False) - - with self.assertRaisesRegexp(ValueError, "same nested structure"): - nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), - check_types=False) - - with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): - nest.map_structure(lambda x: None, structure1, foo="a") - - with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): - nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") - - ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name - - @test_util.assert_no_new_pyobjects_executing_eagerly - def testMapStructureWithStrings(self): - inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) - inp_b = NestTest.ABTuple(a=2, b=(1, 3)) - out = nest.map_structure(lambda string, repeats: string * repeats, - inp_a, - inp_b) - self.assertEqual("foofoo", out.a) - self.assertEqual("bar", out.b[0]) - self.assertEqual("bazbazbaz", out.b[1]) - - nt = NestTest.ABTuple(a=("something", "something_else"), - b="yet another thing") - rev_nt = nest.map_structure(lambda x: x[::-1], nt) - # Check the output is the correct structure, and all strings are reversed. - nest.assert_same_structure(nt, rev_nt) - self.assertEqual(nt.a[0][::-1], rev_nt.a[0]) - self.assertEqual(nt.a[1][::-1], rev_nt.a[1]) - self.assertEqual(nt.b[::-1], rev_nt.b) - - @test_util.run_deprecated_v1 - def testMapStructureOverPlaceholders(self): - inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), - array_ops.placeholder(dtypes.float32, shape=[3, 7])) - inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), - array_ops.placeholder(dtypes.float32, shape=[3, 7])) - - output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) - - nest.assert_same_structure(output, inp_a) - self.assertShapeEqual(np.zeros((3, 4)), output[0]) - self.assertShapeEqual(np.zeros((3, 7)), output[1]) - - feed_dict = { - inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), - inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) - } - - with self.cached_session() as sess: - output_np = sess.run(output, feed_dict=feed_dict) - self.assertAllClose(output_np[0], - feed_dict[inp_a][0] + feed_dict[inp_b][0]) - self.assertAllClose(output_np[1], - feed_dict[inp_a][1] + feed_dict[inp_b][1]) - - def testAssertShallowStructure(self): - inp_ab = ["a", "b"] - inp_abc = ["a", "b", "c"] - expected_message = ( - "The two structures don't have the same sequence length. Input " - "structure has length 2, while shallow structure has length 3.") - with self.assertRaisesRegexp(ValueError, expected_message): - nest.assert_shallow_structure(inp_abc, inp_ab) - - inp_ab1 = [(1, 1), (2, 2)] - inp_ab2 = [[1, 1], [2, 2]] - expected_message = ( - "The two structures don't have the same sequence type. Input structure " - "has type <(type|class) 'tuple'>, while shallow structure has type " - "<(type|class) 'list'>.") - with self.assertRaisesRegexp(TypeError, expected_message): - nest.assert_shallow_structure(inp_ab2, inp_ab1) - nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) - - inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} - inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} - expected_message = ( - r"The two structures don't have the same keys. 
Input " - r"structure has keys \['c'\], while shallow structure has " - r"keys \['d'\].") - - with self.assertRaisesRegexp(ValueError, expected_message): - nest.assert_shallow_structure(inp_ab2, inp_ab1) - - inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) - inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) - nest.assert_shallow_structure(inp_ab, inp_ba) - - # This assertion is expected to pass: two namedtuples with the same - # name and field names are considered to be identical. - inp_shallow = NestTest.SameNameab(1, 2) - inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) - nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) - nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) - - def testFlattenUpTo(self): - # Shallow tree ends at scalar. - input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] - shallow_tree = [[True, True], [False, True]] - flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) - self.assertEqual(flattened_shallow_tree, [True, True, False, True]) - - # Shallow tree ends at string. - input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] - shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] - input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, - input_tree) - input_tree_flattened = nest.flatten(input_tree) - self.assertEqual(input_tree_flattened_as_shallow_tree, - [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) - self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) - - # Make sure dicts are correctly flattened, yielding values, not keys. - input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} - shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} - input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, - input_tree) - self.assertEqual(input_tree_flattened_as_shallow_tree, - [1, {"c": 2}, 3, (4, 5)]) - - # Namedtuples. - ab_tuple = NestTest.ABTuple - input_tree = ab_tuple(a=[0, 1], b=2) - shallow_tree = ab_tuple(a=0, b=1) - input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, - input_tree) - self.assertEqual(input_tree_flattened_as_shallow_tree, - [[0, 1], 2]) - - # Nested dicts, OrderedDicts and namedtuples. - input_tree = collections.OrderedDict( - [("a", ab_tuple(a=[0, {"b": 1}], b=2)), - ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) - shallow_tree = input_tree - input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, - input_tree) - self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) - shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) - input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, - input_tree) - self.assertEqual(input_tree_flattened_as_shallow_tree, - [ab_tuple(a=[0, {"b": 1}], b=2), - 3, - collections.OrderedDict([("f", 4)])]) - shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) - input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, - input_tree) - self.assertEqual(input_tree_flattened_as_shallow_tree, - [ab_tuple(a=[0, {"b": 1}], b=2), - {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) - - ## Shallow non-list edge-case. - # Using iterable elements. 
- input_tree = ["input_tree"] - shallow_tree = "shallow_tree" - flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(flattened_input_tree, [input_tree]) - self.assertEqual(flattened_shallow_tree, [shallow_tree]) - - input_tree = ["input_tree_0", "input_tree_1"] - shallow_tree = "shallow_tree" - flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(flattened_input_tree, [input_tree]) - self.assertEqual(flattened_shallow_tree, [shallow_tree]) - - # Using non-iterable elements. - input_tree = [0] - shallow_tree = 9 - flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(flattened_input_tree, [input_tree]) - self.assertEqual(flattened_shallow_tree, [shallow_tree]) - - input_tree = [0, 1] - shallow_tree = 9 - flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(flattened_input_tree, [input_tree]) - self.assertEqual(flattened_shallow_tree, [shallow_tree]) - - ## Both non-list edge-case. - # Using iterable elements. - input_tree = "input_tree" - shallow_tree = "shallow_tree" - flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(flattened_input_tree, [input_tree]) - self.assertEqual(flattened_shallow_tree, [shallow_tree]) - - # Using non-iterable elements. - input_tree = 0 - shallow_tree = 0 - flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(flattened_input_tree, [input_tree]) - self.assertEqual(flattened_shallow_tree, [shallow_tree]) - - ## Input non-list edge-case. - # Using iterable elements. - input_tree = "input_tree" - shallow_tree = ["shallow_tree"] - expected_message = ("If shallow structure is a sequence, input must also " - "be a sequence. Input has type: <(type|class) 'str'>.") - with self.assertRaisesRegexp(TypeError, expected_message): - flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(flattened_shallow_tree, shallow_tree) - - input_tree = "input_tree" - shallow_tree = ["shallow_tree_9", "shallow_tree_8"] - with self.assertRaisesRegexp(TypeError, expected_message): - flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(flattened_shallow_tree, shallow_tree) - - # Using non-iterable elements. - input_tree = 0 - shallow_tree = [9] - expected_message = ("If shallow structure is a sequence, input must also " - "be a sequence. 
Input has type: <(type|class) 'int'>.") - with self.assertRaisesRegexp(TypeError, expected_message): - flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(flattened_shallow_tree, shallow_tree) - - input_tree = 0 - shallow_tree = [9, 8] - with self.assertRaisesRegexp(TypeError, expected_message): - flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) - flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) - self.assertEqual(flattened_shallow_tree, shallow_tree) - - def testMapStructureUpTo(self): - # Named tuples. - ab_tuple = collections.namedtuple("ab_tuple", "a, b") - op_tuple = collections.namedtuple("op_tuple", "add, mul") - inp_val = ab_tuple(a=2, b=3) - inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) - out = nest.map_structure_up_to( - inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) - self.assertEqual(out.a, 6) - self.assertEqual(out.b, 15) - - # Lists. - data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] - name_list = ["evens", ["odds", "primes"]] - out = nest.map_structure_up_to( - name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), - name_list, data_list) - self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) - - # Dicts. - inp_val = dict(a=2, b=3) - inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) - out = nest.map_structure_up_to( - inp_val, - lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) - self.assertEqual(out["a"], 6) - self.assertEqual(out["b"], 15) - - # Non-equal dicts. - inp_val = dict(a=2, b=3) - inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) - with self.assertRaisesRegexp(ValueError, "same keys"): - nest.map_structure_up_to( - inp_val, - lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) - - # Dict+custom mapping. - inp_val = dict(a=2, b=3) - inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) - out = nest.map_structure_up_to( - inp_val, - lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) - self.assertEqual(out["a"], 6) - self.assertEqual(out["b"], 15) - - # Non-equal dict/mapping. 
- inp_val = dict(a=2, b=3) - inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) - with self.assertRaisesRegexp(ValueError, "same keys"): - nest.map_structure_up_to( - inp_val, - lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) - - def testGetTraverseShallowStructure(self): - scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] - scalar_traverse_r = nest.get_traverse_shallow_structure( - lambda s: not isinstance(s, tuple), - scalar_traverse_input) - self.assertEqual(scalar_traverse_r, - [True, True, False, [True, True], {"a": False}, []]) - nest.assert_shallow_structure(scalar_traverse_r, - scalar_traverse_input) - - structure_traverse_input = [(1, [2]), ([1], 2)] - structure_traverse_r = nest.get_traverse_shallow_structure( - lambda s: (True, False) if isinstance(s, tuple) else True, - structure_traverse_input) - self.assertEqual(structure_traverse_r, - [(True, False), ([True], False)]) - nest.assert_shallow_structure(structure_traverse_r, - structure_traverse_input) - - with self.assertRaisesRegexp(TypeError, "returned structure"): - nest.get_traverse_shallow_structure(lambda _: [True], 0) - - with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): - nest.get_traverse_shallow_structure(lambda _: 1, [1]) - - with self.assertRaisesRegexp( - TypeError, "didn't return a depth=1 structure of bools"): - nest.get_traverse_shallow_structure(lambda _: [1], [1]) - - def testYieldFlatStringPaths(self): - for inputs_expected in ({"inputs": [], "expected": []}, - {"inputs": 3, "expected": [()]}, - {"inputs": [3], "expected": [(0,)]}, - {"inputs": {"a": 3}, "expected": [("a",)]}, - {"inputs": {"a": {"b": 4}}, - "expected": [("a", "b")]}, - {"inputs": [{"a": 2}], "expected": [(0, "a")]}, - {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, - {"inputs": [{"a": [(23, 42)]}], - "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, - {"inputs": [{"a": ([23], 42)}], - "expected": [(0, "a", 0, 0), (0, "a", 1)]}, - {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, - "expected": [("a", "a"), ("c", 0, 0, 0)]}, - {"inputs": {"0": [{"1": 23}]}, - "expected": [("0", 0, "1")]}): - inputs = inputs_expected["inputs"] - expected = inputs_expected["expected"] - self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) - - def testFlattenWithStringPaths(self): - for inputs_expected in ( - {"inputs": [], "expected": []}, - {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, - {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): - inputs = inputs_expected["inputs"] - expected = inputs_expected["expected"] - self.assertEqual( - nest.flatten_with_joined_string_paths(inputs, separator="/"), - expected) - - # Need a separate test for namedtuple as we can't declare tuple definitions - # in the @parameterized arguments. 
- def testFlattenNamedTuple(self): - # pylint: disable=invalid-name - Foo = collections.namedtuple("Foo", ["a", "b"]) - Bar = collections.namedtuple("Bar", ["c", "d"]) - # pylint: enable=invalid-name - test_cases = [ - (Foo(a=3, b=Bar(c=23, d=42)), - [("a", 3), ("b/c", 23), ("b/d", 42)]), - (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")), - [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), - (Bar(c=42, d=43), - [("c", 42), ("d", 43)]), - (Bar(c=[42], d=43), - [("c/0", 42), ("d", 43)]), - ] - for inputs, expected in test_cases: - self.assertEqual( - list(nest.flatten_with_joined_string_paths(inputs)), expected) - - @parameterized.named_parameters( - ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), - ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, - {"a": ("a", 4), "b": ("b", 6)}), - ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), - ("nested", - {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, - {"a": [("a/0", 10), ("a/1", 12)], - "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) - def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): - def format_sum(path, *values): - return (path, sum(values)) - result = nest.map_structure_with_paths(format_sum, s1, s2, - check_types=check_types) - self.assertEqual(expected, result) - - @parameterized.named_parameters( - ("tuples", (1, 2), (3, 4, 5), ValueError), - ("dicts", {"a": 1}, {"b": 2}, ValueError), - ("mixed", (1, 2), [3, 4], TypeError), - ("nested", - {"a": [2, 3], "b": [1, 3]}, - {"b": [5, 6, 7], "a": [8, 9]}, - ValueError - )) - def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): - with self.assertRaises(error_type): - nest.map_structure_with_paths(lambda path, *s: 0, s1, s2) - - -class NestBenchmark(test.Benchmark): - - def run_and_report(self, s1, s2, name): - burn_iter, test_iter = 100, 30000 - - for _ in xrange(burn_iter): - nest.assert_same_structure(s1, s2) - - t0 = time.time() - for _ in xrange(test_iter): - nest.assert_same_structure(s1, s2) - t1 = time.time() - - self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, - name=name) - - def benchmark_assert_structure(self): - s1 = (((1, 2), 3), 4, (5, 6)) - s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) - self.run_and_report(s1, s2, "assert_same_structure_6_elem") - - s1 = (((1, 2), 3), 4, (5, 6)) * 10 - s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 - self.run_and_report(s1, s2, "assert_same_structure_60_elem") - - -if __name__ == "__main__": - test.main() diff --git a/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs b/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs deleted file mode 100644 index bf55cda8..00000000 --- a/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs +++ /dev/null @@ -1,249 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System.Linq; -using Tensorflow; -using Tensorflow.UnitTest; -using static Tensorflow.Binding; - -namespace TensorFlowNET.UnitTest.ops_test -{ - /// - /// excerpt of tensorflow/python/framework/ops_test.py - /// - [TestClass] - public class ControlDependenciesTest : GraphModeTestBase - { - [TestMethod] - public void TestBasic() - { - var g = tf.Graph().as_default(); - Tensor a = null, b = null, c = null, d = null, e = null; - - a = constant_op.constant(1.0); - b = constant_op.constant(1.0); - tf_with(g.control_dependencies(new[] { a }), x => - { - c = constant_op.constant(1.0); - d = array_ops.identity(b); - e = array_ops.identity(c); - }); - - 
Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op })); - Assert.IsTrue(Enumerable.SequenceEqual(d.op.control_inputs, new[] { a.op })); - // e should be dominated by c. - Assert.AreEqual(0, e.op.control_inputs.Length); - } - - [Ignore("How to port the ConvertibleObj?")] - [TestMethod] - public void TestBasicWithConversion() - { - var g = tf.Graph().as_default(); - // Note: _apply_op can be replaced by g.create_op - var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT }); - // TODO: ConvertibleObj, see original source below - /* - def testBasicWithConversion(self): - g = ops.Graph() - a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - class ConvertibleObj(object): - - def _as_graph_element(self): - return a - - with g.control_dependencies([ConvertibleObj()]): - c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - self.assertEqual(c.op.control_inputs, [a.op]) - */ - } - - [TestMethod] - public void TestNested() - { - var g = tf.Graph().as_default(); - var a_1 = constant_op.constant(1.0); - var a_2 = constant_op.constant(3.0); - var a_3 = constant_op.constant(4.0); - var a_4 = constant_op.constant(5.0); - Tensor b_1 = null, b_2 = null; - tf_with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl => - { - b_1 = constant_op.constant(6.0); - }); - tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 => - { - tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 => - { - tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 => - { - tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 => - { - b_2 = constant_op.constant(7.0); - }); - }); - }); - }); - //var z=tf.add(a_1, tf.multiply(b_2, b_1)); - //with(g.control_dependencies(new[] {z}), ctrl => - //{ - // var z1 = tf.add(a_3, tf.multiply(a_4, a_2)); - //}); - //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); - assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); - assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); - } - - [TestMethod] - public void TestClear() - { - var g = tf.Graph().as_default(); - var a_1 = constant_op.constant(1.0); - var a_2 = constant_op.constant(3.0); - var a_3 = constant_op.constant(4.0); - var a_4 = constant_op.constant(5.0); - Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null; - tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 => - { - tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 => - { - tf_with(g.control_dependencies(null), ctrl3 => - { - tf_with(g.control_dependencies(new[] { a_3 }), ctrl4 => - { - tf_with(g.control_dependencies(new[] { a_4 }), ctrl5 => - { - // deps [a_3, a_4] - b_3_4 = constant_op.constant(7.0); - }); - // deps = [a_3] - b_3 = constant_op.constant(8.0); - }); - // deps back to None - b_none = constant_op.constant(9.0); - }); - // deps back to [a_1, a_2] - b_1_2 = constant_op.constant(10.0); - }); - // deps back to [a_1] - b_1 = constant_op.constant(11.0); - tf_with(g.control_dependencies(null), ctrl6 => - { - // deps are None again - b_none2 = constant_op.constant(12.0); - }); - }); - // Note assertItemsEqual(given, expected), expected and given parameters should be swapped below - assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs); - assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs); - assertItemsEqual(new object[0], b_none.op.control_inputs); - assertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs); - assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs); 
- assertItemsEqual(new object[0], b_none2.op.control_inputs); - } - - [TestMethod] - public void TestComplex() - { - var g = tf.Graph().as_default(); - // Usage pattern: - // * Nodes a_i are constants defined at the outermost scope, and are used - // as control inputs for the ith nested scope. - // * Nodes b_i are defined as Mul(a_3, a_4) at each scope. - // * Nodes c_i are defined as Mul(a_1, b_1) at each scope. - // * Nodes d_i are defined as Mul(b_i, c_i) at each scope. - // * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. - var a_1 = constant_op.constant(1.0); - var a_2 = constant_op.constant(2.0); - var a_3 = constant_op.constant(3.0); - var a_4 = constant_op.constant(4.0); - Operation b_1 = null, b_2 = null, b_3 = null, b_4 = null; - Operation c_1 = null, c_2 = null, c_3 = null, c_4 = null; - Operation d_1 = null, d_2 = null, d_3 = null, d_4 = null; - Operation e_1 = null, e_2 = null, e_3 = null, e_4 = null; - tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 => - { - b_1 = tf.multiply(a_3, a_4); - c_1 = tf.multiply(a_1, b_1.output); - d_1 = tf.multiply(b_1.output, c_1.output); - e_1 = constant_op.constant(5.0); - tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 => - { - b_2 = tf.multiply(a_3, a_4); - c_2 = tf.multiply(a_1, b_1.output); - d_2 = tf.multiply(b_2.output, c_2.output); - e_2 = tf.multiply(e_1.output, e_1.output); - tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 => - { - b_3 = tf.multiply(a_3, a_4); - c_3 = tf.multiply(a_1, b_1.output); - d_3 = tf.multiply(b_3.output, c_3.output); - e_3 = tf.multiply(e_2.output, e_2.output); - tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 => - { - b_4 = tf.multiply(a_3, a_4); - c_4 = tf.multiply(a_1, b_1.output); - d_4 = tf.multiply(b_4.output, c_4.output); - e_4 = tf.multiply(e_3.output, e_3.output); - }); - }); - }); - }); - - // Note assertItemsEqual(given, expected), expected and given parameters should be swapped below - assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs); - assertItemsEqual(new[] { a_1.op, a_2.op }, b_2.op.control_inputs); - assertItemsEqual(new[] { a_1.op, a_2.op }, b_3.op.control_inputs); - assertItemsEqual(new[] { a_1.op, a_2.op }, b_4.op.control_inputs); - - assertItemsEqual(new object[0], c_1.op.control_inputs); - assertItemsEqual(new[] { a_2.op }, c_2.op.control_inputs); - assertItemsEqual(new[] { a_2.op, a_3.op }, c_3.op.control_inputs); - assertItemsEqual(new[] { a_2.op, a_3.op, a_4.op }, c_4.op.control_inputs); - - assertItemsEqual(new object[0], d_1.op.control_inputs); - assertItemsEqual(new object[0], d_2.op.control_inputs); - assertItemsEqual(new object[0], d_3.op.control_inputs); - assertItemsEqual(new object[0], d_4.op.control_inputs); - - assertItemsEqual(new[] { a_1.op }, e_1.op.control_inputs); - assertItemsEqual(new[] { a_2.op }, e_2.op.control_inputs); - assertItemsEqual(new[] { a_3.op }, e_3.op.control_inputs); - assertItemsEqual(new[] { a_4.op }, e_4.op.control_inputs); - } - - [Ignore("Don't know how to create an operation with two outputs")] - [TestMethod] - public void TestRepeatedDependency() - { - /* - def testRepeatedDependency(self): - g = ops.Graph() - a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) - a_0, a_1 = a.outputs - with g.control_dependencies([a_0]): - b = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - with g.control_dependencies([a_1]): - c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - self.assertEqual(b.op.control_inputs, [a]) - self.assertEqual(c.op.control_inputs, [a]) - - */ - } - - 
[TestMethod] - public void TestNoControlDependencyWithDataDependency() - { - var g = tf.Graph().as_default(); - Operation b = null; - var a = constant_op.constant(100.0); - tf_with(g.control_dependencies(new[] { a }), ctrl1 => - { - b = array_ops.identity(a); - }); - Assert.AreEqual(0, b.op.control_inputs.Length); - } - - } -} diff --git a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs deleted file mode 100644 index 00200f9a..00000000 --- a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs +++ /dev/null @@ -1,222 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System; -using System.Linq; -using Tensorflow; -using Tensorflow.Operations; -using Tensorflow.UnitTest; -using static Tensorflow.Binding; - -namespace TensorFlowNET.UnitTest.ops_test -{ - /// - /// excerpt of tensorflow/python/framework/ops_test.py - /// # These cases test the private Graph._create_op_from_tf_operation - /// # method. Arguably we should only test the public APIs that depend on this - /// # method. However, this logic is complex and tricky, and it can be difficult to - /// # ascertain if we have adequate coverage (e.g. a graph may run successfully if - /// # the control flow context isn't set properly, but a more complicated use case - /// # that might not be obvious to test will fail). Thus we instead explicitly test - /// # the low-level behavior. - /// - [Ignore] - [TestClass] - public class CreateOpFromTfOperationTest : GraphModeTestBase - { - - [TestMethod] - public void TestShape() - { - using (var g = tf.Graph().as_default()) - { - var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); - var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); - var op = g._create_op_from_tf_operation(c_op); - - Assert.AreEqual("myop", op.name); - Assert.AreEqual("Identity", op.type); - Assert.AreEqual(1, len(op.outputs)); - assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); - } - } - - [TestMethod] - public void TestUniqueName() - { - var graph = tf.Graph().as_default(); - //var (c_op,op_desc) = ops._create_c_op(g, ops._NodeDef("Const", "myop"), new Tensor[0], new Operation[0]); - //var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]); - //var op = g._create_op_from_tf_operation(c_op); - //var op2 = g._create_op_from_tf_operation(c_op2); - var op = constant_op.constant(0, name: "myop").op; - var op2 = constant_op.constant(0, name: "myop_1").op; - - // Create ops with same names as op1 and op2. We expect the new names to be - // uniquified. 
- var op3 = constant_op.constant(0, name: "myop").op; - var op4 = constant_op.constant(0, name: "myop_1").op; - - self.assertEqual(op.name, "myop"); - self.assertEqual(op2.name, "myop_1"); - self.assertEqual(op3.name, "myop_2"); - self.assertEqual(op4.name, "myop_1_1"); - } - - [Ignore("need tesnroflow expose UpdateEdge API")] - [TestMethod] - public void TestCond() - { - var g = tf.Graph().as_default(); - var x = constant_op.constant(10); - - var true_fn = new Func(() => - { - var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); - var new_ops = g._add_new_tf_operations(); - self.assertEqual(len(new_ops), 1); - return x; - }); - - control_flow_ops.cond(x < 10, true_fn, () => x); - - var op = g.get_operation_by_name("cond/myop"); - - //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta.txt", as_text:true); - //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); - - self.assertIsNotNone(op); - self.assertEqual(op.name, "cond/myop"); - self.assertEqual(op.type, "Identity"); - //self.assertEqual(op.outputs, new object[0]); - var op_input = op.inputs[0].op; - self.assertEqual(op_input.type, "Switch"); - self.assertEqual(op_input.inputs[0].name, x.name); - self.assertEqual(op.graph, g); - self.assertIsNotNone(op._get_control_flow_context()); - var cond_text = op._get_control_flow_context() as ControlFlowContext; - self.assertEqual(cond_text.Name, "cond/cond_text"); - } - - [Ignore("Todo: Port")] - [TestMethod] - public void TestWhileLoop() - { - var graph = tf.Graph().as_default(); - Operation x = null; - x = constant_op.constant(42); - var body = new Func(i => - { - ops._create_c_op(ops.get_default_graph(), ops._NodeDef("Identity", "myloop/myop"), new[] { x.output }, - new Operation[0]); - var new_ops = graph._add_new_tf_operations(); - self.assertEqual(len(new_ops), 1); - return i; - }); - // TODO: port control_flow_ops.while_loop - //control_flow_ops.while_loop( i => i < 10, body, new int[]{0}, name = "myloop"); - var op = graph.get_operation_by_name("myloop/myop"); - self.assertIsNotNone(op); - self.assertEqual(op.name, "myloop/myop"); - self.assertEqual(op.type, "Identity"); - self.assertEqual(op.outputs.Length, 0); - var op_input = op.inputs[0].op; - self.assertEqual(op_input.type, "Enter"); - self.assertItemsEqual(op_input.inputs.OfType().ToArray(), new[] { x }); - self.assertEqual(op.graph, graph); - self.assertIsNotNone(op._get_control_flow_context()); - self.assertEqual(((ControlFlowContext)op._get_control_flow_context()).Name, "myloop/while_context"); - /* - @test_util.run_v1_only("b/120545219") - def testWhileLoop(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - - def body(i): - ops._create_c_op(ops.get_default_graph(), - ops._NodeDef("IntInput", "myloop/myop"), [x], []) - new_ops = g._add_new_tf_operations() - self.assertEqual(len(new_ops), 1) - return i - - control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") - - op = g.get_operation_by_name("myloop/myop") - self.assertIsNotNone(op) - self.assertEqual(op.name, "myloop/myop") - self.assertEqual(op.type, "IntInput") - self.assertEqual(op.outputs, []) - op_input = op.inputs[0].op - self.assertEqual(op_input.type, "Enter") - self.assertEqual(list(op_input.inputs), [x]) - self.assertEqual(op.graph, g) - # pylint: disable=protected-access - self.assertIsNotNone(op._get_control_flow_context()) - self.assertEqual(op._get_control_flow_context().name, - "myloop/while_context") - # pylint: 
enable=protected-access - */ - } - - [Ignore("Todo: Port")] - [TestMethod] - public void TestWhileLoopWithInternalControlDep() - { - /* -@test_util.run_v1_only("b/120545219") - def testWhileLoopWithInternalControlDep(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - - def body(i): - c = constant_op.constant(1.0, name="c") - ops._create_c_op(ops.get_default_graph(), - ops._NodeDef("IntInput", "myloop/myop"), [x], []) - with ops.control_dependencies([c]): - new_ops = g._add_new_tf_operations() - self.assertEqual(len(new_ops), 1) - return i - - control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") - - op = g.get_operation_by_name("myloop/myop") - self.assertIsNotNone(op) - c = g.get_operation_by_name("myloop/c") - self.assertIsNotNone(c) - # Internal control dep is preserved - self.assertEqual(op.control_inputs, [c]) - */ - } - - [Ignore("Todo: Port")] - [TestMethod] - public void TestWhileLoopWithExternalControlDep() - { - /* - @test_util.run_v1_only("b/120545219") - def testWhileLoopWithExternalControlDep(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - c = constant_op.constant(1.0) - - def body(i): - ops._create_c_op(ops.get_default_graph(), - ops._NodeDef("IntInput", "myloop/myop"), [x], []) - with ops.control_dependencies([c]): - new_ops = g._add_new_tf_operations() - self.assertEqual(len(new_ops), 1) - return i - - control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") - - op = g.get_operation_by_name("myloop/myop") - self.assertIsNotNone(op) - # External control dep is removed and replaced with internal control dep - self.assertNotEqual(op.control_inputs[0], c.op) - self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) - */ - } - - } -} diff --git a/test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs b/test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs deleted file mode 100644 index d32c188f..00000000 --- a/test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs +++ /dev/null @@ -1,196 +0,0 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow; -using Tensorflow.UnitTest; - -namespace TensorFlowNET.UnitTest.ops_test -{ - /// - /// excerpt of tensorflow/python/framework/ops_test.py - /// - [TestClass] - public class GraphTest : GraphModeTestBase - { - [TestInitialize] - public void SetUp() - { - ops.reset_default_graph(); - } - - [TestCleanup] - public void TearDown() - { - ops.reset_default_graph(); - } - - private void _AssertDefault(Graph expected) - { - Assert.AreSame(ops.get_default_graph(), expected); - } - - - [Ignore("Todo: Port")] - [TestMethod] - public void testResetDefaultGraphNesting() - { - /* - def testResetDefaultGraphNesting(self): - g0 = ops.Graph() - with self.assertRaises(AssertionError): - with g0.as_default(): - ops.reset_default_graph() - */ - } - - [Ignore("Todo: Port")] - [TestMethod] - public void testGraphContextManagerCancelsEager() - { - /* - def testGraphContextManagerCancelsEager(self): - with context.eager_mode(): - with ops.Graph().as_default(): - self.assertFalse(context.executing_eagerly()) - */ - } - - - [Ignore("Todo: Port")] - [TestMethod] - public void testGraphContextManager() - { - /* - def testGraphContextManager(self): - g0 = ops.Graph() - with g0.as_default() as g1: - self.assertIs(g0, g1) - */ - } - - [Ignore("Todo: Port")] - [TestMethod] - public void testDefaultGraph() - { - /* - def testDefaultGraph(self): - orig = ops.get_default_graph() - self._AssertDefault(orig) - g0 = ops.Graph() - self._AssertDefault(orig) - 
context_manager_0 = g0.as_default() - self._AssertDefault(orig) - with context_manager_0 as g0: - self._AssertDefault(g0) - with ops.Graph().as_default() as g1: - self._AssertDefault(g1) - self._AssertDefault(g0) - self._AssertDefault(orig) - */ - } - - [Ignore("Todo: Port")] - [TestMethod] - public void testPreventFeeding() - { - /* - def testPreventFeeding(self): - g = ops.Graph() - a = constant_op.constant(2.0) - self.assertTrue(g.is_feedable(a)) - g.prevent_feeding(a) - self.assertFalse(g.is_feedable(a)) - */ - } - - - [Ignore("Todo: Port")] - [TestMethod] - public void testAsGraphElementConversions() - { - /* - def testAsGraphElementConversions(self): - - class ConvertibleObj(object): - - def _as_graph_element(self): - return "FloatOutput:0" - - class NonConvertibleObj(object): - - pass - - g = ops.Graph() - a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - self.assertEqual(a, g.as_graph_element(ConvertibleObj())) - with self.assertRaises(TypeError): - g.as_graph_element(NonConvertibleObj()) - */ - } - - [Ignore("Todo: Port")] - [TestMethod] - public void testGarbageCollected() - { - /* - # Regression test against creating custom __del__ functions in classes - # involved in cyclic references, e.g. Graph and Operation. (Python won't gc - # cycles that require calling a __del__ method, because the __del__ method can - # theoretically increase the object's refcount to "save" it from gc, and any - # already-deleted objects in the cycle would have be to restored.) - def testGarbageCollected(self): - # Create a graph we can delete and a weak reference to monitor if it's gc'd - g = ops.Graph() - g_ref = weakref.ref(g) - # Create some ops - with g.as_default(): - a = constant_op.constant(2.0) - b = constant_op.constant(3.0) - c = math_ops.add(a, b) - # Create a session we can delete - with session.Session(graph=g) as sess: - self.evaluate(c) - # Delete all references and trigger gc - del g - del a - del b - del c - del sess - gc.collect() - self.assertIsNone(g_ref()) - */ - } - - [Ignore("Todo: Port")] - [TestMethod] - public void testRunnableAfterInvalidShape() - { - /* - def testRunnableAfterInvalidShape(self): - with ops.Graph().as_default(): - with self.assertRaises(ValueError): - math_ops.add([1, 2], [1, 2, 3]) - a = constant_op.constant(1) - with session.Session() as sess: - self.evaluate(a) - */ - } - - [Ignore("Todo: Port")] - [TestMethod] - public void testRunnableAfterInvalidShapeWithKernelLabelMap() - { - /* - def testRunnableAfterInvalidShapeWithKernelLabelMap(self): - g = ops.Graph() - with g.as_default(): - with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): - with self.assertRaises(ValueError): - test_ops.kernel_label_required(1) - a = constant_op.constant(1) - with session.Session() as sess: - self.evaluate(a) - */ - } - - - } -} diff --git a/test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py b/test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py deleted file mode 100644 index 2d7ee1a9..00000000 --- a/test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py +++ /dev/null @@ -1,3014 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tensorflow.python.framework.ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import gc -import os -import threading -import weakref - -from tensorflow.core.framework import attr_value_pb2 -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session -from tensorflow.python.eager import context -from tensorflow.python.eager import function as eager_function -from tensorflow.python.framework import common_shapes -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import device as pydev -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import test_ops -from tensorflow.python.framework import test_util -from tensorflow.python.framework import versions -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import resources -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -import tensorflow.python.ops.gradients # pylint: disable=unused-import -from tensorflow.python.platform import googletest -from tensorflow.python.util import compat - -ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn) - - -class ResourceTest(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testBuildGraph(self): - with self.cached_session(): - pt = test_ops.stub_resource_handle_op(container="a", shared_name="b") - test_ops.resource_create_op(pt).run() - - @test_util.run_deprecated_v1 - def testInitialize(self): - with self.cached_session(): - handle = test_ops.stub_resource_handle_op(container="a", shared_name="b") - resources.register_resource( - handle=handle, - create_op=test_ops.resource_create_op(handle), - is_initialized_op=test_ops.resource_initialized_op(handle)) - self.assertEquals( - len( - resources.report_uninitialized_resources( - resources.shared_resources()).eval()), 1) - resources.initialize_resources(resources.shared_resources()).run() - self.assertEquals( - len( - resources.report_uninitialized_resources( - resources.shared_resources()).eval()), 0) - - -class TensorAndShapeTest(test_util.TensorFlowTestCase): - - def testShape(self): - op = ops.Operation( - ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) - t = op.outputs[0] - self.assertEqual(tensor_shape.unknown_shape(), t.get_shape()) - t.set_shape([1, 2, 3]) - self.assertEqual([1, 2, 3], t.get_shape()) - - def testIterable(self): - op = ops.Operation( - ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], 
[dtypes.float32]) - t = op.outputs[0] - self.assertTrue(isinstance(t, ops.Tensor)) - with self.assertRaisesRegexp(TypeError, "iter"): - for _ in t: - pass - - def testAddShape(self): - with self.cached_session(): - a = array_ops.zeros([2, 3]) - b = array_ops.ones([1, 3]) - c = a + b - self.assertEqual([2, 3], c.shape) - - @test_util.run_deprecated_v1 - def testUnknownDim(self): - with self.cached_session(): - a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) - b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) - c = a + b - self.assertEqual([2, None, 3], c.shape.as_list()) - - @test_util.run_deprecated_v1 - def testUnknownShape(self): - with self.cached_session(): - a = array_ops.placeholder(dtype=dtypes.float32, shape=None) - b = array_ops.ones([1, 3]) - c = a + b - self.assertEqual(tensor_shape.unknown_shape(), c.shape) - - @test_util.run_deprecated_v1 - def testScalarShape(self): - with self.cached_session(): - a = array_ops.placeholder(dtype=dtypes.float32, shape=[]) - b = array_ops.ones([]) - c = a + b - self.assertEqual(tensor_shape.scalar(), c.shape) - - @test_util.run_deprecated_v1 - def testShapeFunctionError(self): - with self.cached_session(): - a = array_ops.ones([1, 2, 3]) - b = array_ops.ones([4, 5, 6]) - with self.assertRaisesRegexp( - ValueError, - r"Dimensions must be equal, but are 2 and 5 for 'add' \(op: 'Add'\) " - r"with input shapes: \[1,2,3\], \[4,5,6\]."): - _ = a + b - - -class IndexedSlicesTest(test_util.TensorFlowTestCase): - - @test_util.run_in_graph_and_eager_modes - def testToTensor(self): - values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) - indices = constant_op.constant([0, 2]) - dense_shape = constant_op.constant([3, 2]) - x = ops.IndexedSlices(values, indices, dense_shape) - tensor = ops.convert_to_tensor(x, name="tensor") - self.assertAllEqual(self.evaluate(tensor), [[2, 3], [0, 0], [5, 7]]) - - @test_util.run_deprecated_v1 - def testNegation(self): - with self.cached_session(): - values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) - indices = constant_op.constant([0, 2]) - x = -ops.IndexedSlices(values, indices) - self.assertAllEqual(x.values.eval(), [[-2, -3], [-5, -7]]) - self.assertAllEqual(x.indices.eval(), [0, 2]) - - @test_util.run_deprecated_v1 - def testScalarMul(self): - with self.cached_session(): - values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) - indices = constant_op.constant([0, 2]) - x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices)) - self.assertAllEqual(x.values.eval(), [[-4, -6], [-10, -14]]) - self.assertAllEqual(x.indices.eval(), [0, 2]) - - -class NodeDefConstructorTest(test_util.TensorFlowTestCase): - - def testNoArgs(self): - nodedef = ops._NodeDef("None", "bar") - self.assertProtoEquals("op: 'None' name: 'bar'", nodedef) - - def testArgs(self): - nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*") - self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'", - nodedef) - nodedef = ops._NodeDef("foo", "bar", device=pydev.DeviceSpec(job="j")) - self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef) - - -def _apply_op(g, *args, **kwargs): - op = g.create_op(*args, **kwargs) - if len(op.outputs) == 1: - return op.outputs[0] - else: - return op.outputs - - -class OperationTest(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testNoInputs(self): - op = test_ops.float_output_string_output(name="myop").a.op - self.assertEqual(2, len(op.values())) - self.assertEqual(0, len(op.inputs)) - self.assertEqual("myop", 
op.name) - - float_t, label_str_t = op.values() - self.assertEqual(dtypes.float32, float_t.dtype) - self.assertEqual(op, float_t.op) - self.assertEqual(0, float_t._value_index) - self.assertEqual(0, len(float_t.consumers())) - self.assertEqual("myop", float_t._as_node_def_input()) - - self.assertEqual(dtypes.string, label_str_t.dtype) - self.assertEqual(op, label_str_t.op) - self.assertEqual(1, label_str_t._value_index) - self.assertEqual(0, len(label_str_t.consumers())) - self.assertEqual("myop:1", label_str_t._as_node_def_input()) - - self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'", - op.node_def) - - @test_util.run_deprecated_v1 - def testNoOutputs(self): - op1 = test_ops.float_output(name="myop1").op - float_t, = op1.values() - op2 = test_ops.float_input(float_t, name="myop2") - self.assertEqual(0, len(op2.values())) - self.assertEqual(1, len(op2.inputs)) - self.assertIs(float_t, op2.inputs[0]) - - self.assertEqual(1, len(float_t.consumers())) - self.assertEqual(op2, float_t.consumers()[0]) - - self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def) - self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'", - op2.node_def) - - @test_util.run_deprecated_v1 - def testInputsAndOutputs(self): - op1 = test_ops.float_output(name="myop1").op - self.assertEqual(1, len(op1.values())) - float1_t, = op1.values() - - op2 = test_ops.float_output_string_output(name="myop2").a.op - self.assertEqual(2, len(op2.values())) - float2_t, label2_str_t = op2.values() - - # Note that we consume label2_str_t twice here. - op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op - self.assertEqual(2, len(op3.values())) - - self.assertEqual(1, len(float1_t.consumers())) - self.assertEqual(op3, float1_t.consumers()[0]) - - self.assertEqual(0, len(float2_t.consumers())) - - self.assertEqual(2, len(label2_str_t.consumers())) - self.assertEqual(op3, label2_str_t.consumers()[0]) - self.assertEqual(op3, label2_str_t.consumers()[1]) - - self.assertProtoEquals(""" - op:'Foo2' name:'myop3' - input:'myop1' input:'myop2:1' input:'myop2:1' - """, op3.node_def) - - def testDeviceFromNodeDef(self): - op = ops.Operation( - ops._NodeDef("None", "myop", device="/job:goo/device:GPU:0"), - ops.Graph(), [], []) - self.assertEqual("/job:goo/device:GPU:0", op.device) - - def testDeviceObject(self): - op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], []) - op._set_device("/job:goo/device:GPU:0") - self.assertProtoEquals( - "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def) - op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], []) - op._set_device( - pydev.DeviceSpec( - job="muu", device_type="CPU", device_index=0)) - self.assertProtoEquals( - "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def) - - def testReferenceInput(self): - g = ops.Graph() - op1 = ops.Operation( - ops._NodeDef("RefOutputFloatOutput", "op1"), g, [], - [dtypes.float32_ref, dtypes.float32]) - self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) - self.assertEquals([], list(op1.inputs)) - ref_t, nonref_t = op1.values() - # NOTE(mrry): Must specify input_types to preserve ref-typed input. 
- op2 = ops.Operation( - ops._NodeDef("RefInputFloatInput", "op2"), - g, [ref_t, nonref_t], [], - input_types=[dtypes.float32_ref, dtypes.float32]) - self.assertProtoEquals( - "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", - op2.node_def) - self.assertEquals([ref_t, nonref_t], list(op2.inputs)) - op3 = ops.Operation( - ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], []) - self.assertProtoEquals( - "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", - op3.node_def) - - def testInvalidNames(self): - g = ops.Graph() - with self.assertRaises(ValueError): - ops.Operation(ops._NodeDef("op", ""), g) - with self.assertRaises(ValueError): - ops.Operation(ops._NodeDef("op", "_invalid"), g) - with self.assertRaises(ValueError): - ops.Operation(ops._NodeDef("op", "-invalid"), g) - with self.assertRaises(ValueError): - ops.Operation(ops._NodeDef("op", "/invalid"), g) - with self.assertRaises(ValueError): - ops.Operation(ops._NodeDef("op", "invalid:0"), g) - - @test_util.run_deprecated_v1 - def testNoShapeFunction(self): - op = test_ops.a() - self.assertEqual(tensor_shape.unknown_shape(), op.get_shape()) - - @test_util.run_in_graph_and_eager_modes - def testConvertToTensorNestedArray(self): - values = [[2], [3], [5], [7]] - tensor = ops.convert_to_tensor(values) - self.assertAllEqual((4, 1), tensor.get_shape().as_list()) - self.assertAllEqual(values, self.evaluate(tensor)) - - def testShapeTuple(self): - with self.cached_session(): - c = constant_op.constant(1) - self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access - - def testConvertToTensorEager(self): - with context.eager_mode(): - t = constant_op.constant(1) - self.assertTrue(isinstance(t, ops.EagerTensor)) - converted = ops.convert_to_tensor(t) - self.assertTrue(isinstance(converted, ops.EagerTensor)) - converted = ops.convert_to_tensor(1) - self.assertTrue(isinstance(converted, ops.EagerTensor)) - - @test_util.run_in_graph_and_eager_modes - def testConvertToTensorNestedTuple(self): - values = ((2,), (3,), (5,), (7,)) - tensor = ops.convert_to_tensor(values) - self.assertAllEqual((4, 1), tensor.get_shape().as_list()) - self.assertAllEqual(values, self.evaluate(ops.convert_to_tensor(values))) - - @test_util.run_in_graph_and_eager_modes - def testConvertToTensorNestedTensors(self): - values = ((2,), (3,), (5,), (7,)) - tensor = ops.convert_to_tensor( - [constant_op.constant(row) for row in values]) - self.assertAllEqual((4, 1), tensor.get_shape().as_list()) - self.assertAllEqual(values, self.evaluate(tensor)) - tensor = ops.convert_to_tensor( - [[constant_op.constant(v) for v in row] for row in values]) - self.assertAllEqual((4, 1), tensor.get_shape().as_list()) - self.assertAllEqual(values, self.evaluate(tensor)) - - @test_util.run_in_graph_and_eager_modes - def testConvertToTensorNestedMix(self): - values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7])) - tensor = ops.convert_to_tensor(values) - self.assertAllEqual((4, 1), tensor.get_shape().as_list()) - self.assertAllEqual(((2,), (3,), (5,), (7,)), self.evaluate(tensor)) - - @test_util.run_in_graph_and_eager_modes - def testConvertToTensorPreferred(self): - values = [2, 3, 5, 7] - tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32) - self.assertEqual(dtypes.float32, tensor.dtype) - - # Convert empty tensor to anything. 
- values = [] - tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) - self.assertEqual(dtypes.int64, tensor.dtype) - - # The preferred dtype is a type error and will convert to - # float32 instead. - values = [1.23] - tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) - self.assertEqual(dtypes.float32, tensor.dtype) - - @test_util.run_in_graph_and_eager_modes - def testConvertToInvalidTensorType(self): - with self.assertRaises(TypeError): - # Forcing an invalid dtype should fail with a type error. - values = [1.23] - ops.convert_to_tensor(values, dtype=dtypes.int64) - - @test_util.run_in_graph_and_eager_modes - def testConvertToTensorFromInvalidTensor(self): - tensor = constant_op.constant(42.0, dtype=dtypes.float32) - with self.assertRaises(ValueError): - ops.convert_to_tensor(tensor, dtype=dtypes.int32) - - @test_util.run_deprecated_v1 - def testNoConvert(self): - # Operation cannot be converted to Tensor. - op = control_flow_ops.no_op() - with self.assertRaisesRegexp(TypeError, - r"Can't convert Operation '.*' to Tensor"): - ops.convert_to_tensor(op) - - def testStr(self): - node_def = ops._NodeDef("None", "op1") - op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32]) - self.assertEqual(str(node_def), str(op)) - - def testRepr(self): - op = ops.Operation( - ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32]) - self.assertEqual("", repr(op)) - - @test_util.run_deprecated_v1 - def testGetAttr(self): - op = test_ops.default_attrs() - self.assertEqual(op.get_attr("string_val"), b"abc") - self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""]) - self.assertEqual(op.get_attr("int_val"), 123) - self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3]) - self.assertEqual(op.get_attr("float_val"), 10.0) - self.assertEqual(op.get_attr("float_list_val"), [10.0]) - self.assertEqual(op.get_attr("bool_val"), True) - self.assertEqual(op.get_attr("bool_list_val"), [True, False]) - self.assertEqual(op.get_attr("shape_val"), - tensor_shape.as_shape([2, 1]).as_proto()) - self.assertEqual(op.get_attr("shape_list_val"), - [tensor_shape.as_shape([]).as_proto(), - tensor_shape.as_shape([1]).as_proto()]) - self.assertEqual(op.get_attr("tensor_val"), - tensor_util.make_tensor_proto(1, dtypes.int32)) - self.assertEqual(op.get_attr("tensor_list_val"), - [tensor_util.make_tensor_proto(1, dtypes.int32)]) - - type_val = op.get_attr("type_val") - # First check that type_val is a DType, because the assertEquals will work - # no matter what since DType overrides __eq__ - self.assertIsInstance(type_val, dtypes.DType) - self.assertEqual(type_val, dtypes.int32) - - type_list_val = op.get_attr("type_list_val") - self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val)) - self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32]) - - @function.Defun(dtypes.float32, func_name="MyFunc") - def func(x): - return x - - op = test_ops.func_attr(func) - self.assertEqual(op.get_attr("f"), - attr_value_pb2.NameAttrList(name="MyFunc")) - - # Try fetching missing attr - with self.assertRaisesRegexp( - ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."): - op.get_attr("FakeAttr") - - # TODO(b/65162920): remove this test when users who are directly mutating the - # node_def have been updated to proper usage. 
- @test_util.run_deprecated_v1 - def testSetAttr(self): - op = test_ops.int_attr().op - op._set_attr("foo", attr_value_pb2.AttrValue(i=2)) - # TODO(skyewm): add node_def check - self.assertEqual(op.get_attr("foo"), 2) - - # TODO(nolivia): test all error cases - def testAddControlInput(self): - with ops.Graph().as_default(): - x = constant_op.constant(1).op - y = constant_op.constant(2).op - z = constant_op.constant(3).op - z._add_control_input(x) # pylint: disable=protected-access - self.assertEqual(z.control_inputs, [x]) - z._add_control_input(x) # pylint: disable=protected-access - self.assertEqual(z.control_inputs, [x]) - z._add_control_inputs([x, y, y]) # pylint: disable=protected-access - self.assertEqual(z.control_inputs, [x, y]) - self.assertEqual(x._control_outputs, [z]) - - @test_util.run_deprecated_v1 - def testRemoveAllControlInputs(self): - a = constant_op.constant(1) - with ops.control_dependencies([a]): - b = constant_op.constant(2) - c = constant_op.constant(3) - d = constant_op.constant(4) - e = constant_op.constant(5) - with ops.control_dependencies([a, c]): - f = d + e - - self.assertEqual(a.op.control_inputs, []) - self.assertEqual(b.op.control_inputs, [a.op]) - self.assertEqual(f.op.control_inputs, [a.op, c.op]) - - a.op._remove_all_control_inputs() # pylint: disable=protected-access - self.assertEqual(a.op.control_inputs, []) - - b.op._remove_all_control_inputs() # pylint: disable=protected-access - self.assertEqual(b.op.control_inputs, []) - - f.op._remove_all_control_inputs() # pylint: disable=protected-access - self.assertEqual(f.op.control_inputs, []) - self.assertEqual(list(f.op.inputs), [d, e]) - - @test_util.run_deprecated_v1 - def testControlInputCycle(self): - graph = ops.Graph() - with graph.as_default(): - z = constant_op.constant(0) - x = constant_op.constant(1) - y = constant_op.constant(2) - y.op._add_control_input(z.op) # pylint: disable=protected-access - y.op._add_control_input(x.op) # pylint: disable=protected-access - x.op._add_control_input(y.op) # pylint: disable=protected-access - with self.session(graph=graph) as sess: - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Graph is invalid, contains a cycle with 2 nodes"): - self.evaluate(x) - - def testUpdateInput(self): - g = ops.Graph() - with g.as_default(): - x = constant_op.constant(1) - y = constant_op.constant(2) - z = x + y - - z.op._update_input(0, y) # pylint: disable=protected-access - self.assertEquals(list(z.op.inputs), [y, y]) - self.assertEquals(x.consumers(), []) - self.assertEquals(y.consumers(), [z.op, z.op]) - with session.Session(graph=g) as sess: - self.assertEquals(self.evaluate(z), 4) - - z.op._update_input(0, x) # pylint: disable=protected-access - self.assertEquals(list(z.op.inputs), [x, y]) - self.assertEquals(x.consumers(), [z.op]) - self.assertEquals(y.consumers(), [z.op]) - with session.Session(graph=g) as sess: - self.assertEquals(self.evaluate(z), 3) - - z.op._update_input(1, y) # pylint: disable=protected-access - self.assertEquals(list(z.op.inputs), [x, y]) - self.assertEquals(x.consumers(), [z.op]) - self.assertEquals(y.consumers(), [z.op]) - with session.Session(graph=g) as sess: - self.assertEquals(self.evaluate(z), 3) - - def testUpdateInputGraphError(self): - g_0 = ops.Graph() - g_1 = ops.Graph() - with g_0.as_default(): - x = constant_op.constant(1) - with g_1.as_default(): - y = constant_op.constant(2) - z = y * 2 - with self.assertRaisesRegexp(ValueError, "must be from the same graph"): - z.op._update_input(0, x) # pylint: 
disable=protected-access - - def testUpdateInputTypeError(self): - g = ops.Graph() - with g.as_default(): - w = constant_op.constant(0) - x = constant_op.constant("") - y = constant_op.constant(1) - z = y + w - z.op._update_input(0, x) # pylint: disable=protected-access - with session.Session(graph=g) as sess: - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Input 0 of node add was passed string from Const_1:0 incompatible " - "with expected int32"): - self.evaluate(z) - - def testUpdateInputShapeError(self): - g = ops.Graph() - with g.as_default(): - w = constant_op.constant(2, shape=[3, 1]) - x = constant_op.constant(0, shape=[3, 1]) - y = constant_op.constant(1, shape=[2, 2]) - z = w + x - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"): - z.op._update_input(0, y) # pylint: disable=protected-access - - def testUpdateInputOutOfRange(self): - g = ops.Graph() - with g.as_default(): - x = constant_op.constant(1) - with self.assertRaisesRegexp( - errors.OutOfRangeError, - r"Cannot update edge. Input index \[1\] is greater than the number of " - r"total inputs \[0\]." - ): - x.op._update_input(1, x) # pylint: disable=protected-access - - @test_util.enable_control_flow_v2 - @test_util.run_v1_only("b/120545219") - def testAddWhileInput(self): - @eager_function.defun - def test(): - output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1, - [1]) - while_op = output.op.inputs[0].op - self.assertEqual(while_op.type, "While") - orig_num_inputs = len(while_op.inputs) - - # Make sure we can handle the while op having a control input. - while_op._add_control_input(constant_op.constant(0).op) - - new_input1 = constant_op.constant(1.0) - new_input2 = constant_op.constant(True) - - while_op._set_type_list_attr("T", - [t.dtype for t in while_op.inputs] + - [new_input1.dtype, new_input2.dtype]) - - while_op._add_while_inputs([new_input1, new_input2]) - # Can't add an edge beyond what's specified by "T" - with self.assertRaises(errors.OutOfRangeError): - while_op._add_while_inputs([new_input2]) - self.assertEqual(len(while_op.inputs), orig_num_inputs + 2) # pylint: disable=g-deprecated-assert - - test() - - @test_util.run_deprecated_v1 - def testOpDef(self): - x = constant_op.constant(0) - y = constant_op.constant(1) - z = x + y - - self.assertEqual(x.op.op_def.name, "Const") - self.assertEqual(len(x.op.op_def.input_arg), 0) - self.assertEqual(len(x.op.op_def.output_arg), 1) - - self.assertEqual(z.op.op_def.name, "Add") - self.assertEqual(len(z.op.op_def.input_arg), 2) - self.assertEqual(len(z.op.op_def.output_arg), 1) - - def testInputFromDifferentGraphError(self): - g_0 = ops.Graph() - g_1 = ops.Graph() - with g_0.as_default(): - x = constant_op.constant(1) - with g_1.as_default(): - y = constant_op.constant(2) - with self.assertRaisesRegexp(ValueError, "must be from the same graph"): - y * x # pylint: disable=pointless-statement - - def testInputsAreImmutable(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - op = test_ops.int_input_int_output(x, name="myop").op - with self.assertRaisesRegexp( - AttributeError, "'_InputList' object has no attribute 'append'"): - op.inputs.append(None) - - -class CreateOpTest(test_util.TensorFlowTestCase): - - def testNodeDefArgs(self): - g = ops.Graph() - op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") - with g.device("/device:GPU:0"): - op2 = g.create_op( - "FloatOutputStringOutput", [], [dtypes.float32, 
dtypes.string], None, - name="myop2") - op3 = g.create_op( - "Foo3", - [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]], - [dtypes.float32, dtypes.int32], - None, - name="myop3") - self.assertDeviceEqual(None, op1.device) - self.assertDeviceEqual("/device:GPU:0", op2.device) - self.assertDeviceEqual(None, op3.device) - self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def) - self.assertProtoEquals( - "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'", - op2.node_def) - self.assertProtoEquals( - "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'", - op3.node_def) - - def testReferenceInput(self): - g = ops.Graph() - op1 = g.create_op( - "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32], - name="op1") - self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) - ref_t, nonref_t = op1.values() - # NOTE(mrry): Must specify input_types to preserve ref-typed input. - op2 = g.create_op( - "RefInputFloatInput", [ref_t, nonref_t], [], - input_types=[dtypes.float32_ref, dtypes.float32], - name="op2") - self.assertProtoEquals( - "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", - op2.node_def) - op3 = g.create_op("TwoFloatInputs", [ref_t, nonref_t], [], name="op3") - self.assertProtoEquals( - "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", - op3.node_def) - - def testFinalized(self): - g = ops.Graph() - g.finalize() - with self.assertRaises(RuntimeError): - g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") - - # Test unfinalize. - g._unsafe_unfinalize() - g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") - - -# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation -# method. Arguably we should only test the public APIs that depend on this -# method. However, this logic is complex and tricky, and it can be difficult to -# ascertain if we have adequate coverage (e.g. a graph may run successfully if -# the control flow context isn't set properly, but a more complicated use case -# that might not be obvious to test will fail). Thus we instead explicitly test -# the low-level behavior. 
-class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testBasic(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - c_op = ops._create_c_op( - g, ops._NodeDef("IntInputIntOutput", "myop"), [x], []) - op = g._create_op_from_tf_operation(c_op) - - self.assertEqual(op.name, "myop") - self.assertEqual(op.type, "IntInputIntOutput") - self.assertEqual(len(op.outputs), 1) - self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape()) - self.assertEqual(list(op.inputs), [x]) - self.assertEqual(op.control_inputs, []) - self.assertEqual(op.graph, g) - self.assertEqual(x.consumers(), [op]) - self.assertIsNotNone(op.traceback) - self.assertEqual(g.get_operation_by_name("myop"), op) - self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0]) - - def testShape(self): - g = ops.Graph() - with g.as_default(): - x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) - c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], []) - op = g._create_op_from_tf_operation(c_op) - - self.assertEqual(op.name, "myop") - self.assertEqual(op.type, "Identity") - self.assertEqual(len(op.outputs), 1) - self.assertEqual(op.outputs[0].shape, tensor_shape.matrix(2, 3)) - - def testUniqueName(self): - g = ops.Graph() - with g.as_default(): - c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], []) - c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], []) - op = g._create_op_from_tf_operation(c_op) - op2 = g._create_op_from_tf_operation(c_op2) - - # Create ops with same names as op1 and op2. We expect the new names to be - # uniquified. - op3 = test_ops.int_output(name="myop").op - op4 = test_ops.int_output(name="myop_1").op - - self.assertEqual(op.name, "myop") - self.assertEqual(op2.name, "myop_1") - self.assertEqual(op3.name, "myop_2") - self.assertEqual(op4.name, "myop_1_1") - - @test_util.run_v1_only("b/120545219") - def testCond(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - - def true_fn(): - ops._create_c_op(ops.get_default_graph(), - ops._NodeDef("IntInput", "cond/myop"), [x], []) - new_ops = g._add_new_tf_operations() - self.assertEqual(len(new_ops), 1) - return x - - control_flow_ops.cond(x < 10, true_fn, lambda: x) - - op = g.get_operation_by_name("cond/myop") - self.assertIsNotNone(op) - self.assertEqual(op.name, "cond/myop") - self.assertEqual(op.type, "IntInput") - self.assertEqual(op.outputs, []) - op_input = op.inputs[0].op - self.assertEqual(op_input.type, "Switch") - self.assertEqual(op_input.inputs[0], x) - self.assertEqual(op.graph, g) - # pylint: disable=protected-access - self.assertIsNotNone(op._get_control_flow_context()) - self.assertEqual(op._get_control_flow_context().name, - "cond/cond_text") - # pylint: enable=protected-access - - @test_util.run_v1_only("b/120545219") - def testWhileLoop(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - - def body(i): - ops._create_c_op(ops.get_default_graph(), - ops._NodeDef("IntInput", "myloop/myop"), [x], []) - new_ops = g._add_new_tf_operations() - self.assertEqual(len(new_ops), 1) - return i - - control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") - - op = g.get_operation_by_name("myloop/myop") - self.assertIsNotNone(op) - self.assertEqual(op.name, "myloop/myop") - self.assertEqual(op.type, "IntInput") - self.assertEqual(op.outputs, []) - op_input = op.inputs[0].op - self.assertEqual(op_input.type, "Enter") - self.assertEqual(list(op_input.inputs), [x]) - 
self.assertEqual(op.graph, g) - # pylint: disable=protected-access - self.assertIsNotNone(op._get_control_flow_context()) - self.assertEqual(op._get_control_flow_context().name, - "myloop/while_context") - # pylint: enable=protected-access - - @test_util.run_v1_only("b/120545219") - def testWhileLoopWithInternalControlDep(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - - def body(i): - c = constant_op.constant(1.0, name="c") - ops._create_c_op(ops.get_default_graph(), - ops._NodeDef("IntInput", "myloop/myop"), [x], []) - with ops.control_dependencies([c]): - new_ops = g._add_new_tf_operations() - self.assertEqual(len(new_ops), 1) - return i - - control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") - - op = g.get_operation_by_name("myloop/myop") - self.assertIsNotNone(op) - c = g.get_operation_by_name("myloop/c") - self.assertIsNotNone(c) - # Internal control dep is preserved - self.assertEqual(op.control_inputs, [c]) - - @test_util.run_v1_only("b/120545219") - def testWhileLoopWithExternalControlDep(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.int_output() - c = constant_op.constant(1.0) - - def body(i): - ops._create_c_op(ops.get_default_graph(), - ops._NodeDef("IntInput", "myloop/myop"), [x], []) - with ops.control_dependencies([c]): - new_ops = g._add_new_tf_operations() - self.assertEqual(len(new_ops), 1) - return i - - control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") - - op = g.get_operation_by_name("myloop/myop") - self.assertIsNotNone(op) - # External control dep is removed and replaced with internal control dep - self.assertNotEqual(op.control_inputs[0], c.op) - self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) - - -class ApplyOpTest(test_util.TensorFlowTestCase): - - def testNodeDefArgs(self): - g = ops.Graph() - t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") - with g.device("/device:GPU:0"): - t2 = _apply_op( - g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2") - t3 = _apply_op( - g, - "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32], - name="myop3") - self.assertTrue(isinstance(t1, ops.Tensor)) - self.assertTrue(isinstance(t2, list)) - self.assertTrue(isinstance(t3, list)) - self.assertTrue(isinstance(t3[0], ops.Tensor)) - self.assertEqual("myop1", t1._as_node_def_input()) - self.assertEqual("myop2", t2[0]._as_node_def_input()) - self.assertEqual("myop2:1", t2[1]._as_node_def_input()) - self.assertEqual("myop3", t3[0]._as_node_def_input()) - # Validate that we got the right ops as well - self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def) - self.assertProtoEquals( - "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'", - t2[0].op.node_def) - self.assertProtoEquals( - "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'", - t3[0].op.node_def) - - def testReferenceInput(self): - g = ops.Graph() - ref_t, nonref_t = _apply_op( - g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32], - name="op1") - self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", - ref_t.op.node_def) - # NOTE(mrry): Must specify input_types to preserve ref-typed input. 
- out_2 = _apply_op( - g, - "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32], - input_types=[dtypes.float32_ref, dtypes.float32], - name="op2") - self.assertProtoEquals( - "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'", - out_2.op.node_def) - out_3 = _apply_op( - g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32], - name="op3") - self.assertProtoEquals( - "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'", - out_3.op.node_def) - - -class NameStackTest(test_util.TensorFlowTestCase): - - def testBasics(self): - g = ops.Graph() - self.assertEqual("foo", g.unique_name("foo", mark_as_used=False)) - self.assertEqual("foo", g.unique_name("foo", mark_as_used=False)) - self.assertEqual("foo", g.unique_name("foo")) - self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False)) - self.assertEqual("foo_1", g.unique_name("foo")) - self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False)) - self.assertEqual("foo_2", g.unique_name("foo")) - self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False)) - self.assertEqual("foo_1_1", g.unique_name("foo_1")) - self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False)) - self.assertEqual("foo_1_2", g.unique_name("foo_1")) - self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False)) - self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2")) - with g.name_scope("bar"): - self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False)) - self.assertEqual("bar/foo", g.unique_name("foo")) - self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False)) - self.assertEqual("bar/foo_1", g.unique_name("foo")) - with g.name_scope(None): - self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False)) - self.assertEqual("foo_3", g.unique_name("foo")) - with g.name_scope("baz"): - self.assertEqual( - "bar/baz/foo", g.unique_name( - "foo", mark_as_used=False)) - self.assertEqual("bar/baz/foo", g.unique_name("foo")) - self.assertEqual( - "bar/baz/foo_1", g.unique_name( - "foo", mark_as_used=False)) - self.assertEqual("bar/baz/foo_1", g.unique_name("foo")) - with g.name_scope("baz"): - self.assertEqual( - "bar/baz_1/foo", g.unique_name( - "foo", mark_as_used=False)) - self.assertEqual("bar/baz_1/foo", g.unique_name("foo")) - self.assertEqual( - "bar/baz_1/foo_1", g.unique_name( - "foo", mark_as_used=False)) - self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo")) - with g.name_scope("quux"): - self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False)) - self.assertEqual("quux/foo", g.unique_name("foo")) - with g.name_scope("bar"): - with g.name_scope("baz"): - self.assertEqual( - "bar_1/baz/foo", g.unique_name( - "foo", mark_as_used=False)) - self.assertEqual("bar_1/baz/foo", g.unique_name("foo")) - self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False)) - self.assertEqual("foo_4", g.unique_name("foo")) - self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False)) - self.assertEqual("bar_2", g.unique_name("bar")) - - @test_util.run_deprecated_v1 - def testNameAndVariableScope(self): - with self.cached_session() as sess: - with sess.graph.name_scope("l0"): - with variable_scope.variable_scope("l1"): - with sess.graph.name_scope("l1") as scope: - self.assertEqual("l0/l1/l1/", scope) - self.assertEqual( - "l0/l1/l1/foo", - sess.graph.unique_name( - "foo", mark_as_used=False)) - self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo")) - with sess.graph.name_scope("l2") as scope: - 
self.assertEqual("l0/l1/l2/", scope) - self.assertEqual( - "l0/l1/l2/foo", - sess.graph.unique_name( - "foo", mark_as_used=False)) - self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo")) - - def testOutOfOrderUniqueName(self): - g = ops.Graph() - self.assertEqual("foo_2", g.unique_name("foo_2")) - self.assertEqual("foo", g.unique_name("foo")) - self.assertEqual("foo_1", g.unique_name("foo")) - self.assertEqual("foo_3", g.unique_name("foo")) - - def testUniqueNameCaseInsensitivity(self): - g = ops.Graph() - self.assertEqual("foo", g.unique_name("foo")) - self.assertEqual("Foo_1", g.unique_name("Foo")) - with g.name_scope("bar"): - self.assertEqual("bar/foo", g.unique_name("foo")) - with g.name_scope("Bar"): - self.assertEqual("Bar_1/foo", g.unique_name("foo")) - - def testInvalidNameRaisesError(self): - g = ops.Graph() - with g.name_scope(""): # Should not raise - pass - with g.name_scope("foo/"): # Should not raise - with g.name_scope("_bar"): # Should not raise - pass - with self.assertRaises(ValueError): - with g.name_scope("foo:0"): - pass - with self.assertRaises(ValueError): - with g.name_scope("_bar"): - pass - - -class NameTest(test_util.TensorFlowTestCase): - - def testGenerateName(self): - g = ops.Graph() - op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) - self.assertEqual("TwoFloatOutputs", op0.name) - self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name) - self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name) - - op1 = g.create_op("FloatOutput", [], [dtypes.float32]) - self.assertEqual("FloatOutput", op1.name) - self.assertEqual("FloatOutput:0", op1.outputs[0].name) - - op2 = g.create_op("FloatOutput", [], [dtypes.float32]) - self.assertEqual("FloatOutput_1", op2.name) - self.assertEqual("FloatOutput_1:0", op2.outputs[0].name) - - op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op") - self.assertEqual("my_op", op3.name) - self.assertEqual("my_op:0", op3.outputs[0].name) - - def testNameScope(self): - g = ops.Graph() - - with g.name_scope("foo") as foo: - self.assertEqual("foo/", foo) - with g.name_scope("foo2") as foo2: - self.assertEqual("foo/foo2/", foo2) - with g.name_scope(None) as empty1: - self.assertEqual("", empty1) - with g.name_scope("foo3") as foo3: - self.assertEqual("foo3/", foo3) - with g.name_scope("") as empty2: - self.assertEqual("", empty2) - - self.assertEqual("FloatOutput", - g.create_op("FloatOutput", [], [dtypes.float32]).name) - with g.name_scope("bar") as scope: - self.assertEqual("bar/FloatOutput", - g.create_op("FloatOutput", [], [dtypes.float32]).name) - self.assertEqual("bar/FloatOutput_1", - g.create_op("FloatOutput", [], [dtypes.float32]).name) - # If you use the value from "with .. as", that values is used as-is. - self.assertEqual( - "bar", g.create_op( - "FloatOutput", [], [dtypes.float32], name=scope).name) - with g.name_scope("baz") as scope: - with g.name_scope("quux"): - self.assertEqual("baz/quux/FloatOutput", - g.create_op("FloatOutput", [], [dtypes.float32]).name) - # If you use the value from the enclosing "with .. as", nothing is pushed. 
- with g.name_scope(scope): - self.assertEqual("baz/FloatOutput", - g.create_op("FloatOutput", [], [dtypes.float32]).name) - self.assertEqual( - "baz", g.create_op( - "FloatOutput", [], [dtypes.float32], name=scope).name) - self.assertEqual( - "trailing", - g.create_op( - "FloatOutput", [], [dtypes.float32], name="trailing/").name) - with g.name_scope("bar"): - self.assertEqual("bar_1/FloatOutput", - g.create_op("FloatOutput", [], [dtypes.float32]).name) - with g.name_scope("bar/"): - self.assertEqual("bar/FloatOutput_2", - g.create_op("FloatOutput", [], [dtypes.float32]).name) - - -class DeviceTest(test_util.TensorFlowTestCase): - - def testNoDevice(self): - g = ops.Graph() - op = g.create_op("FloatOutput", [], [dtypes.float32]) - self.assertDeviceEqual(None, op.device) - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput" op: "FloatOutput" } - """, gd) - - def testEagerBackingDevice(self): - with context.eager_mode(): - with ops.device("/device:CPU:0"): - t = constant_op.constant(1.0) - self.assertRegexpMatches(t.device, "/device:CPU:0") - self.assertRegexpMatches(t.backing_device, "/device:CPU:0") - - def testDevicePartialString(self): - g = ops.Graph() - with g.device("/job:worker/replica:2"): - g.create_op("FloatOutput", [], [dtypes.float32]) - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput" op: "FloatOutput" - device: "/job:worker/replica:2" } - """, gd) - - def testDeviceFull(self): - g = ops.Graph() - with g.device( - pydev.DeviceSpec( - job="worker", replica=2, task=0, device_type="CPU", - device_index=3)): - g.create_op("FloatOutput", [], [dtypes.float32]) - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput" op: "FloatOutput" - device: "/job:worker/replica:2/task:0/device:CPU:3" } - """, gd) - - def testNesting(self): - g = ops.Graph() - with g.device("/job:worker/replica:2"): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device("/job:worker/replica:3/task:0"): - g.create_op("FloatOutput", [], [dtypes.float32]) - g.create_op("FloatOutput", [], [dtypes.float32]) - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput" op: "FloatOutput" - device: "/job:worker/replica:2" } - node { name: "FloatOutput_1" op: "FloatOutput" - device: "/job:worker/replica:3/task:0" } - node { name: "FloatOutput_2" op: "FloatOutput" - device: "/job:worker/replica:2" } - """, gd) - - def testNestingString(self): - g = ops.Graph() - with g.device("/job:worker/replica:2"): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device("/job:worker/replica:3/task:0"): - g.create_op("FloatOutput", [], [dtypes.float32]) - g.create_op("FloatOutput", [], [dtypes.float32]) - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput" op: "FloatOutput" - device: "/job:worker/replica:2" } - node { name: "FloatOutput_1" op: "FloatOutput" - device: "/job:worker/replica:3/task:0" } - node { name: "FloatOutput_2" op: "FloatOutput" - device: "/job:worker/replica:2" } - """, gd) - - def testNestingOverrideGpuCpu(self): - g = ops.Graph() - with g.device("/job:worker/replica:2/device:CPU:1"): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device("/job:worker/replica:2/device:GPU:2"): - g.create_op("FloatOutput", [], [dtypes.float32]) - g.create_op("FloatOutput", [], [dtypes.float32]) - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput" op: "FloatOutput" - device: 
"/job:worker/replica:2/device:CPU:1" } - node { name: "FloatOutput_1" op: "FloatOutput" - device: "/job:worker/replica:2/device:GPU:2" } - node { name: "FloatOutput_2" op: "FloatOutput" - device: "/job:worker/replica:2/device:CPU:1" } - """, gd) - - def testNestingWithMergeDeviceFunction(self): - g = ops.Graph() - - with g.device(pydev.merge_device("/device:GPU:0")): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device(pydev.merge_device("/job:worker")): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device(pydev.merge_device("/device:CPU:0")): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device(pydev.merge_device("/job:ps")): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device(pydev.merge_device(None)): - g.create_op("FloatOutput", [], [dtypes.float32]) - - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput" op: "FloatOutput" - device: "/device:GPU:0" } - node { name: "FloatOutput_1" op: "FloatOutput" - device: "/job:worker/device:GPU:0" } - node { name: "FloatOutput_2" op: "FloatOutput" - device: "/job:worker/device:CPU:0" } - node { name: "FloatOutput_3" op: "FloatOutput" - device: "/job:ps/device:CPU:0" } - node { name: "FloatOutput_4" op: "FloatOutput" - device: "/job:ps/device:CPU:0" } - """, gd) - - def testNestingWithDeviceStrings(self): - g = ops.Graph() - - with g.device("/device:GPU:0"): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device("/job:worker"): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device("/device:CPU:0"): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device("/job:ps"): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device(""): - g.create_op("FloatOutput", [], [dtypes.float32]) - - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput" op: "FloatOutput" - device: "/device:GPU:0" } - node { name: "FloatOutput_1" op: "FloatOutput" - device: "/job:worker/device:GPU:0" } - node { name: "FloatOutput_2" op: "FloatOutput" - device: "/job:worker/device:CPU:0" } - node { name: "FloatOutput_3" op: "FloatOutput" - device: "/job:ps/device:CPU:0" } - node { name: "FloatOutput_4" op: "FloatOutput" - device: "/job:ps/device:CPU:0" } - """, gd) - - def testNestingWithDeviceStringWildcard(self): - g = ops.Graph() - - with g.device("/device:GPU:7"): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device("/device:GPU:*"): - g.create_op("FloatOutput", [], [dtypes.float32]) - - with g.device("/device:CPU:*"): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device("/device:CPU:5"): - g.create_op("FloatOutput", [], [dtypes.float32]) - - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput" op: "FloatOutput" - device: "/device:GPU:7" } - node { name: "FloatOutput_1" op: "FloatOutput" - device: "/device:GPU:7" } - node { name: "FloatOutput_2" op: "FloatOutput" - device: "/device:CPU:*" } - node { name: "FloatOutput_3" op: "FloatOutput" - device: "/device:CPU:5" } - """, gd) - - def testNoneClearsDefault(self): - g = ops.Graph() - with g.device("/job:worker/replica:2/device:CPU:1"): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device(None): - g.create_op("FloatOutput", [], [dtypes.float32]) - g.create_op("FloatOutput", [], [dtypes.float32]) - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput" op: "FloatOutput" - device: "/job:worker/replica:2/device:CPU:1" } - node { name: "FloatOutput_1" op: 
"FloatOutput" } - node { name: "FloatOutput_2" op: "FloatOutput" - device: "/job:worker/replica:2/device:CPU:1" } - """, gd) - - def testNoneIgnoresOuterDeviceFunction(self): - g = ops.Graph() - with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device(None): - g.create_op("FloatOutput", [], [dtypes.float32]) - g.create_op("FloatOutput", [], [dtypes.float32]) - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput" op: "FloatOutput" - device: "/job:worker/replica:2/device:CPU:1" } - node { name: "FloatOutput_1" op: "FloatOutput" } - node { name: "FloatOutput_2" op: "FloatOutput" - device: "/job:worker/replica:2/device:CPU:1" } - """, gd) - - def _overwritingDeviceFunction(self, unused_op): - # This device function unconditionally overwrites the device of ops. - # - # NOTE(mrry): Writing device functions like this is not - # recommended. Instead, in most cases you should use - # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the - # argument to `tf.device()` and the device component will be merged in. - return "/job:overwrite" - - def testOverwritingBehavior(self): - g = ops.Graph() - with g.device(self._overwritingDeviceFunction): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device("/job:ps"): # Will be overwritten. - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device(pydev.merge_device("/job:ps")): # Will be overwritten. - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device(None): # Disables overwriting device function - with g.device("/job:ps"): - g.create_op("FloatOutput", [], [dtypes.float32]) - with g.device(None): # Disables overwriting device function - with g.device(pydev.merge_device("/job:ps")): - g.create_op("FloatOutput", [], [dtypes.float32]) - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput" op: "FloatOutput" - device: "/job:overwrite" } - node { name: "FloatOutput_1" op: "FloatOutput" - device: "/job:overwrite" } - node { name: "FloatOutput_2" op: "FloatOutput" - device: "/job:overwrite" } - node { name: "FloatOutput_3" op: "FloatOutput" - device: "/job:ps" } - node { name: "FloatOutput_4" op: "FloatOutput" - device: "/job:ps" } - """, gd) - - -class MultithreadedGraphStateTest(test_util.TensorFlowTestCase): - - class TestThread(threading.Thread): - - def __init__(self, graph, replica_id): - super(MultithreadedGraphStateTest.TestThread, self).__init__() - self._graph = graph - self._replica_id = replica_id - # This thread sets this event when it mutated the graph. The caller can - # wait for that. - self.has_mutated_graph = threading.Event() - # This thread waits for when it should continue. The caller can set this - # event. - self.should_continue = threading.Event() - - def run(self): - # Mutate a graph's stack, then set `has_mutated_graph`, then wait for - # `should_continue`, then add an op to the graph affected by the graph's - # stack. 
- raise NotImplementedError("must be implemented in descendants") - - def testDeviceFunctionStack(self): - - class DeviceSettingThread(self.TestThread): - - def run(self): - with g.device("/job:worker/replica:{}".format(self._replica_id)): - self.has_mutated_graph.set() - self.should_continue.wait() - self.should_continue.clear() - g.create_op( - "FloatOutput", [], [dtypes.float32], - name="FloatOutput_{}".format(self._replica_id)) - - g = ops.Graph() - # If `switch_to_thread` isn't called, then device placement of the ops - # below is not deterministic. - g.switch_to_thread_local() - threads = [DeviceSettingThread(g, i) for i in range(3)] - for t in threads: - t.start() - t.has_mutated_graph.wait() - t.has_mutated_graph.clear() - for t in threads: - t.should_continue.set() - t.join() - - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "FloatOutput_0" op: "FloatOutput" - device: "/job:worker/replica:0" } - node { name: "FloatOutput_1" op: "FloatOutput" - device: "/job:worker/replica:1" } - node { name: "FloatOutput_2" op: "FloatOutput" - device: "/job:worker/replica:2" } - """, gd) - - def testColocateWith(self): - - class ColocatingThread(self.TestThread): - - def __init__(self, graph, replica_id, op_to_colocate_with): - super(ColocatingThread, self).__init__(graph, replica_id) - self._op_to_colocate_with = op_to_colocate_with - - def run(self): - with g.colocate_with(self._op_to_colocate_with): - self.has_mutated_graph.set() - self.should_continue.wait() - self.should_continue.clear() - g.create_op( - "FloatOutput", [], [dtypes.float32], - name="FloatOutput_{}".format(self._replica_id)) - - g = ops.Graph() - ops_to_colocate_with = [] - for i in range(3): - with g.device("/job:worker/replica:{}".format(i)): - ops_to_colocate_with.append( - g.create_op( - "FloatOutput", [], [dtypes.float32], - name="ColocateWithMe_{}".format(i))) - - # If `switch_to_thread` isn't called, then `device` and `attr` values for - # the ops below are not deterministic. 
- g.switch_to_thread_local() - threads = [ - ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3) - ] - for t in threads: - t.start() - t.has_mutated_graph.wait() - t.has_mutated_graph.clear() - for t in threads: - t.should_continue.set() - t.join() - - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "ColocateWithMe_0" op: "FloatOutput" - device: "/job:worker/replica:0" } - node { name: "ColocateWithMe_1" op: "FloatOutput" - device: "/job:worker/replica:1" } - node { name: "ColocateWithMe_2" op: "FloatOutput" - device: "/job:worker/replica:2" } - node { name: "FloatOutput_0" op: "FloatOutput" - device: "/job:worker/replica:0" - attr { key: "_class" - value { list { - s: "loc:@ColocateWithMe_0"}}}} - node { name: "FloatOutput_1" op: "FloatOutput" - device: "/job:worker/replica:1" - attr { key: "_class" - value { list { - s: "loc:@ColocateWithMe_1"}}}} - node { name: "FloatOutput_2" op: "FloatOutput" - device: "/job:worker/replica:2" - attr { key: "_class" - value { list { - s: "loc:@ColocateWithMe_2"}}}} - """, gd) - - def testControlDependencies(self): - - class DependingThread(self.TestThread): - - def __init__(self, graph, replica_id, dependency_op): - super(DependingThread, self).__init__(graph, replica_id) - self._dependency_op = dependency_op - - def run(self): - with g.control_dependencies([self._dependency_op]): - self.has_mutated_graph.set() - self.should_continue.wait() - self.should_continue.clear() - g.create_op( - "FloatOutput", [], [dtypes.float32], - name="FloatOutput_{}".format(self._replica_id)) - - g = ops.Graph() - dependency_ops = [] - for i in range(3): - dependency_ops.append( - g.create_op( - "FloatOutput", [], [dtypes.float32], - name="ColocateWithMe_{}".format(i))) - - # If `switch_to_thread` isn't called, then `input` values for the ops below - # are not deterministic. 
- g.switch_to_thread_local() - threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)] - for t in threads: - t.start() - t.has_mutated_graph.wait() - t.has_mutated_graph.clear() - for t in threads: - t.should_continue.set() - t.join() - - gd = g.as_graph_def() - self.assertProtoEqualsVersion(""" - node { name: "ColocateWithMe_0" op: "FloatOutput" } - node { name: "ColocateWithMe_1" op: "FloatOutput" } - node { name: "ColocateWithMe_2" op: "FloatOutput" } - node { name: "FloatOutput_0" op: "FloatOutput" - input: "^ColocateWithMe_0" } - node { name: "FloatOutput_1" op: "FloatOutput" - input: "^ColocateWithMe_1" } - node { name: "FloatOutput_2" op: "FloatOutput" - input: "^ColocateWithMe_2" } - """, gd) - - def testNameStack(self): - - class NameSettingThread(self.TestThread): - - def run(self): - with g.name_scope("foo"): - op1 = g.create_op("FloatOutput", [], [dtypes.float32]) - self.has_mutated_graph.set() - self.should_continue.wait() - self.should_continue.clear() - op2 = g.create_op("FloatOutput", [], [dtypes.float32]) - self.result = (op1, op2) - - g = ops.Graph() - threads = [NameSettingThread(g, i) for i in range(3)] - for t in threads: - t.start() - t.has_mutated_graph.wait() - t.has_mutated_graph.clear() - - for t in threads: - t.should_continue.set() - t.join() - - suffixes = ["", "_1", "_2"] - for t, s in zip(threads, suffixes): - self.assertEquals("foo" + s + "/FloatOutput", t.result[0].name) - self.assertEquals("foo" + s + "/FloatOutput_1", t.result[1].name) - - -class ObjectWithName(object): - - def __init__(self, name): - self._name = name - - @property - def name(self): - return self._name - - -class CollectionTest(test_util.TensorFlowTestCase): - - def test_get_collections(self): - g = ops.Graph() - self.assertSequenceEqual(g.collections, []) - g.add_to_collection("key", 12) - g.add_to_collection("key", 15) - self.assertSequenceEqual(g.collections, ["key"]) - g.add_to_collection("other", "foo") - self.assertSequenceEqual(sorted(g.collections), ["key", "other"]) - - def test_add_to_collection(self): - g = ops.Graph() - g.add_to_collection("key", 12) - g.add_to_collection("other", "foo") - g.add_to_collection("key", 34) - - # Note that only blank1 is returned. - g.add_to_collection("blah", 27) - blank1 = ObjectWithName("prefix/foo") - g.add_to_collection("blah", blank1) - blank2 = ObjectWithName("junk/foo") - g.add_to_collection("blah", blank2) - - self.assertEqual([12, 34], g.get_collection("key")) - self.assertEqual([], g.get_collection("nothing")) - self.assertEqual([27, blank1, blank2], g.get_collection("blah")) - self.assertEqual([blank1], g.get_collection("blah", "prefix")) - self.assertEqual([blank1], g.get_collection("blah", ".*x")) - - # Make sure that get_collection() returns a first-level - # copy of the collection, while get_collection_ref() returns - # the original list. - other_collection_snapshot = g.get_collection("other") - other_collection_ref = g.get_collection_ref("other") - self.assertEqual(["foo"], other_collection_snapshot) - self.assertEqual(["foo"], other_collection_ref) - g.add_to_collection("other", "bar") - self.assertEqual(["foo"], other_collection_snapshot) - self.assertEqual(["foo", "bar"], other_collection_ref) - self.assertEqual(["foo", "bar"], g.get_collection("other")) - self.assertTrue(other_collection_ref is g.get_collection_ref("other")) - - # Verify that getting an empty collection ref returns a modifiable list. 
- empty_coll_ref = g.get_collection_ref("empty") - self.assertEqual([], empty_coll_ref) - empty_coll = g.get_collection("empty") - self.assertEqual([], empty_coll) - self.assertFalse(empty_coll is empty_coll_ref) - empty_coll_ref2 = g.get_collection_ref("empty") - self.assertTrue(empty_coll_ref2 is empty_coll_ref) - # Add to the collection. - empty_coll_ref.append("something") - self.assertEqual(["something"], empty_coll_ref) - self.assertEqual(["something"], empty_coll_ref2) - self.assertEqual([], empty_coll) - self.assertEqual(["something"], g.get_collection("empty")) - empty_coll_ref3 = g.get_collection_ref("empty") - self.assertTrue(empty_coll_ref3 is empty_coll_ref) - - def test_add_to_collections_uniquify(self): - g = ops.Graph() - g.add_to_collections([1, 2, 1], "key") - # Make sure "key" is not added twice - self.assertEqual(["key"], g.get_collection(1)) - - def test_add_to_collections_from_list(self): - g = ops.Graph() - g.add_to_collections(["abc", "123"], "key") - self.assertEqual(["key"], g.get_collection("abc")) - self.assertEqual(["key"], g.get_collection("123")) - - def test_add_to_collections_from_tuple(self): - g = ops.Graph() - g.add_to_collections(("abc", "123"), "key") - self.assertEqual(["key"], g.get_collection("abc")) - self.assertEqual(["key"], g.get_collection("123")) - - def test_add_to_collections_from_generator(self): - g = ops.Graph() - - def generator(): - yield "abc" - yield "123" - - g.add_to_collections(generator(), "key") - self.assertEqual(["key"], g.get_collection("abc")) - self.assertEqual(["key"], g.get_collection("123")) - - def test_add_to_collections_from_set(self): - g = ops.Graph() - g.add_to_collections(set(["abc", "123"]), "key") - self.assertEqual(["key"], g.get_collection("abc")) - self.assertEqual(["key"], g.get_collection("123")) - - def test_add_to_collections_from_string(self): - g = ops.Graph() - g.add_to_collections("abc", "key") - self.assertEqual(["key"], g.get_collection("abc")) - - def test_default_graph(self): - with ops.Graph().as_default(): - ops.add_to_collection("key", 90) - ops.add_to_collection("key", 100) - # Collections are ordered. 
- self.assertEqual([90, 100], ops.get_collection("key")) - - def test_defun(self): - with context.eager_mode(): - - @eager_function.defun - def defun(): - ops.add_to_collection("int", 1) - ops.add_to_collection("tensor", constant_op.constant(2)) - - @eager_function.defun - def inner_defun(): - self.assertEqual(ops.get_collection("int"), [1]) - three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0] - ops.add_to_collection("int", 2) - self.assertEqual(ops.get_collection("int"), [1, 2]) - ops.add_to_collection("foo", "bar") - self.assertEqual(ops.get_collection("foo"), ["bar"]) - return three - - self.assertEqual(ops.get_collection("int"), [1]) - three = inner_defun() - self.assertEqual(ops.get_collection("int"), [1]) - self.assertEqual(ops.get_collection("foo"), []) - return three - - three = defun() - self.assertEqual(three.numpy(), 3) - - -ops.NotDifferentiable("FloatOutput") - - -@ops.RegisterGradient("CopyOp") -def _CopyGrad(op, x_grad): # pylint: disable=invalid-name - _ = op - return x_grad - - -@ops.RegisterGradient("copy_override") -def _CopyOverrideGrad(op, x_grad): # pylint: disable=invalid-name - _ = op - return x_grad - - -class RegistrationTest(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testRegisterGradients(self): - x = test_ops.float_output() - y = test_ops.copy_op(x) - fn = ops.get_gradient_function(y.op) - self.assertEqual(_CopyGrad, fn) - - def testOverrideGradients(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.float_output() - with g.gradient_override_map({"CopyOp": "copy_override"}): - y = test_ops.copy_op(x) - fn = ops.get_gradient_function(y.op) - self.assertEqual(_CopyOverrideGrad, fn) - - def testNonExistentOverride(self): - g = ops.Graph() - with g.as_default(): - x = test_ops.float_output() - with g.gradient_override_map({"CopyOp": "unknown_override"}): - y = test_ops.copy_op(x) - with self.assertRaisesRegexp(LookupError, "unknown_override"): - ops.get_gradient_function(y.op) - - -class ComparisonTest(test_util.TensorFlowTestCase): - - def testMembershipAllowed(self): - g = ops.Graph() - t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") - t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2") - self.assertTrue(isinstance(t1, ops.Tensor)) - self.assertTrue(isinstance(t2, ops.Tensor)) - self.assertTrue(t1 in [t1]) - self.assertTrue(t1 not in [t2]) - - -class ControlDependenciesTest(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testBasic(self): - g = ops.Graph() - with g.as_default(): - # Creating unregistered ops with _apply_op() doesn't work with the C API - # TODO(skyewm): address this more consistently. Possible solutions are - # to use registered ops in all tests, create a way to register ops in - # Python tests, or conditionally disable the op registration check in - # the C API. - a = constant_op.constant(1.0) - b = constant_op.constant(1.0) - with g.control_dependencies([a]): - c = constant_op.constant(1.0) - d = array_ops.identity(b) - e = array_ops.identity(c) - - self.assertEqual(c.op.control_inputs, [a.op]) - self.assertEqual(d.op.control_inputs, [a.op]) - # e should be dominated by c. 
- self.assertEqual(e.op.control_inputs, []) - - @test_util.run_in_graph_and_eager_modes - def testEager(self): - def future(): - future.calls += 1 - return constant_op.constant(2.0) - future.calls = 0 - - if context.executing_eagerly(): - a = constant_op.constant(1.0) - b = future - with ops.control_dependencies([a, b]): - c = constant_op.constant(3.0) - self.assertEqual(future.calls, 1) - else: - g = ops.Graph() - with g.as_default(): - a = constant_op.constant(1.0) - b = future() - with g.control_dependencies([a, b]): - c = constant_op.constant(3.0) - self.assertEqual(c.op.control_inputs, [a.op, b.op]) - self.assertEqual(future.calls, 1) - - def testBasicWithConversion(self): - g = ops.Graph() - a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - class ConvertibleObj(object): - - def _as_graph_element(self): - return a - - with g.control_dependencies([ConvertibleObj()]): - c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - self.assertEqual(c.op.control_inputs, [a.op]) - - def testNested(self): - g = ops.Graph() - a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - with g.control_dependencies([a_1, a_2, a_3, a_4]): - b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - with g.control_dependencies([a_1]): - with g.control_dependencies([a_2]): - with g.control_dependencies([a_3]): - with g.control_dependencies([a_4]): - b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op], - b_1.op.control_inputs) - self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs) - - def testClear(self): - g = ops.Graph() - a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - with g.control_dependencies([a_1]): - with g.control_dependencies([a_2]): - with g.control_dependencies(None): - with g.control_dependencies([a_3]): - with g.control_dependencies([a_4]): - # deps [a_3, a_4] - b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - # deps = [a_3] - b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - # deps back to None - b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - # deps back to [a_1, a_2] - b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - # deps back to [a_1] - b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - with g.control_dependencies(None): - # deps are None again - b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) - self.assertItemsEqual([a_3.op], b_3.op.control_inputs) - self.assertItemsEqual([], b_none.op.control_inputs) - self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) - self.assertItemsEqual([a_1.op], b_1.op.control_inputs) - self.assertItemsEqual([], b_none2.op.control_inputs) - - def testComplex(self): - g = ops.Graph() - - # Usage pattern: - # * Nodes a_i are constants defined at the outermost scope, and are used - # as control inputs for the ith nested scope. - # * Nodes b_i are defined as Mul(a_3, a_4) at each scope. - # * Nodes c_i are defined as Mul(a_1, b_1) at each scope. - # * Nodes d_i are defined as Mul(b_i, c_i) at each scope. - # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. 
- - a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - with g.control_dependencies([a_1]): - b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], - [dtypes.float32]) - c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], - [dtypes.float32]) - d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1], - [dtypes.float32]) - e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - with g.control_dependencies([a_2]): - b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], - [dtypes.float32]) - c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], - [dtypes.float32]) - d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2], - [dtypes.float32]) - e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1], - [dtypes.float32]) - with g.control_dependencies([a_3]): - b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], - [dtypes.float32]) - c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], - [dtypes.float32]) - d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3], - [dtypes.float32]) - e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2], - [dtypes.float32]) - with g.control_dependencies([a_4]): - b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], - [dtypes.float32]) - c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], - [dtypes.float32]) - d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4], - [dtypes.float32]) - e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3], - [dtypes.float32]) - - self.assertItemsEqual([a_1.op], b_1.op.control_inputs) - self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs) - self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs) - self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs) - - self.assertItemsEqual([], c_1.op.control_inputs) - self.assertItemsEqual([a_2.op], c_2.op.control_inputs) - self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs) - self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs) - - self.assertItemsEqual([], d_1.op.control_inputs) - self.assertItemsEqual([], d_2.op.control_inputs) - self.assertItemsEqual([], d_3.op.control_inputs) - self.assertItemsEqual([], d_4.op.control_inputs) - - self.assertItemsEqual([a_1.op], e_1.op.control_inputs) - self.assertItemsEqual([a_2.op], e_2.op.control_inputs) - self.assertItemsEqual([a_3.op], e_3.op.control_inputs) - self.assertItemsEqual([a_4.op], e_4.op.control_inputs) - - def testRepeatedDependency(self): - g = ops.Graph() - a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) - a_0, a_1 = a.outputs - with g.control_dependencies([a_0]): - b = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - with g.control_dependencies([a_1]): - c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - self.assertEqual(b.op.control_inputs, [a]) - self.assertEqual(c.op.control_inputs, [a]) - - def testNoControlDependencyWithDataDependency(self): - g = ops.Graph() - a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - with g.control_dependencies([a]): - b = _apply_op(g, "Identity", [a], [dtypes.float32]) - - self.assertEqual(b.op.control_inputs, []) - - -class OpScopeTest(test_util.TensorFlowTestCase): - - @test_util.run_in_graph_and_eager_modes - def testNames(self): - with ops.name_scope("foo") as foo: - self.assertEqual("foo/", foo) - with ops.name_scope("foo2") as foo2: - 
self.assertEqual("foo/foo2/", foo2) - with ops.name_scope(None) as empty1: - self.assertEqual("", empty1) - with ops.name_scope("foo3") as foo3: - self.assertEqual("foo3/", foo3) - with ops.name_scope("") as empty2: - self.assertEqual("", empty2) - with ops.name_scope("foo/") as outer_foo: - self.assertEqual("foo/", outer_foo) - with ops.name_scope("") as empty3: - self.assertEqual("", empty3) - with ops.name_scope("foo4") as foo4: - self.assertEqual("foo/foo4/", foo4) - with ops.name_scope("foo5//") as foo5: - self.assertEqual("foo5//", foo5) - with ops.name_scope("foo6") as foo6: - self.assertEqual("foo5//foo6/", foo6) - with ops.name_scope("/") as foo7: - self.assertEqual("/", foo7) - with ops.name_scope("//") as foo8: - self.assertEqual("//", foo8) - with ops.name_scope("a//b/c") as foo9: - self.assertEqual("foo/a//b/c/", foo9) - with ops.name_scope("a//b/c") as foo10: - self.assertEqual("a//b/c/", foo10) - - @test_util.run_in_graph_and_eager_modes - def testEagerDefaultScopeName(self): - with ops.name_scope(None, "default") as scope: - self.assertEqual(scope, "default/") - with ops.name_scope(None, "default2") as scope2: - self.assertEqual(scope2, "default/default2/") - - @test_util.run_deprecated_v1 - def testNoScopeName(self): - g0 = ops.Graph() - values = [ - g0.create_op("A", [], [dtypes.float32]), - g0.create_op("B", [], [dtypes.float32]) - ] - with self.assertRaises(ValueError): - with ops.name_scope(None, values=values): - pass - with self.assertRaises(ValueError): - with ops.name_scope(None, None, values): - pass - - @test_util.run_deprecated_v1 - def testEmptyScopeName(self): - g0 = ops.Graph() - a = g0.create_op("A", [], [dtypes.float32]) - b = g0.create_op("B", [], [dtypes.float32]) - with ops.name_scope("", values=[a, b]) as scope: - self.assertEqual("", scope) - self.assertEqual(g0, ops.get_default_graph()) - with ops.name_scope("", "my_default_scope", [a, b]) as scope: - self.assertEqual("", scope) - self.assertEqual(g0, ops.get_default_graph()) - - @test_util.run_deprecated_v1 - def testDefaultScopeName(self): - g0 = ops.Graph() - a = g0.create_op("A", [], [dtypes.float32]) - b = g0.create_op("B", [], [dtypes.float32]) - scope_name = "my_scope" - default_scope_name = "my_default_scope" - with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope: - self.assertEqual("%s/" % scope_name, scope) - self.assertEqual(g0, ops.get_default_graph()) - with ops.name_scope(None, default_scope_name, [a, b]) as scope: - self.assertEqual("%s/" % default_scope_name, scope) - self.assertEqual(g0, ops.get_default_graph()) - - def _testGraphElements(self, graph_elements): - scope_name = "my_scope" - with ops.name_scope(scope_name, values=graph_elements) as scope: - self.assertEqual("%s/" % scope_name, scope) - self.assertEqual(graph_elements[0].graph, ops.get_default_graph()) - g1 = ops.Graph() - a = g1.create_op("A", [], [dtypes.float32]) - with self.assertRaises(ValueError): - with ops.name_scope(scope_name, values=graph_elements + [a]): - pass - - @test_util.run_deprecated_v1 - def testTensor(self): - g0 = ops.Graph() - a = g0.create_op("A", [], [dtypes.float32]) - b = g0.create_op("B", [], [dtypes.float32]) - self._testGraphElements([a, b]) - - @test_util.run_deprecated_v1 - def testSparseTensor(self): - g0 = ops.Graph() - a = g0.create_op("A", [], [dtypes.float32]) - b = g0.create_op("B", [], [dtypes.float32]) - sparse = sparse_tensor.SparseTensor( - _apply_op(g0, "Int64Output", [], [dtypes.int64]), - _apply_op(g0, "FloatOutput", [], [dtypes.float32]), - _apply_op(g0, 
"Int64Output", [], [dtypes.int64])) - self._testGraphElements([a, sparse, b]) - - @test_util.run_deprecated_v1 - def testVariable(self): - g0 = ops.Graph() - with g0.as_default(): - variable = variables.Variable([1.0]) - a = g0.create_op("A", [], [dtypes.float32]) - b = g0.create_op("B", [], [dtypes.float32]) - self._testGraphElements([a, variable, b]) - - -class InitScopeTest(test_util.TensorFlowTestCase): - - def testClearsControlDependencies(self): - g = ops.Graph() - a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - with g.as_default(): - with g.control_dependencies([a_1]): - with g.control_dependencies([a_2]): - with ops.init_scope(): - with g.control_dependencies([a_3]): - with g.control_dependencies([a_4]): - # deps [a_3, a_4] - b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - # deps = [a_3] - b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - # deps back to None - b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - # deps back to [a_1, a_2] - b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - # deps back to [a_1] - b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - with ops.init_scope(): - # deps are None again - b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) - self.assertItemsEqual([a_3.op], b_3.op.control_inputs) - self.assertItemsEqual([], b_none.op.control_inputs) - self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) - self.assertItemsEqual([a_1.op], b_1.op.control_inputs) - self.assertItemsEqual([], b_none2.op.control_inputs) - - def testLiftsOpsFromFunctions(self): - g0 = ops.Graph() - g1 = ops.Graph() - g1._building_function = True # pylint: disable=protected-access - g2 = ops.Graph() - g2._building_function = True # pylint: disable=protected-access - - with g0.as_default(): - with g1.as_default(): - with g2.as_default(): - with ops.init_scope(): - _ = constant_op.constant(1.0) - - self.assertEqual(len(g2.get_operations()), 0) - self.assertEqual(len(g1.get_operations()), 0) - self.assertEqual(len(g0.get_operations()), 1) - - def testPreservesDevices(self): - g0 = ops.Graph() - with g0.as_default(), ops.device("CPU:0"): - g1 = ops.Graph() - g1._building_function = True # pylint: disable=protected-access - with g1.as_default(), ops.device("GPU:0"): - with ops.init_scope(): - # init_scope should preserve device set under `g1`. - on_gpu = constant_op.constant(1.0) - self.assertEqual(on_gpu.device, "/device:GPU:0") - still_on_gpu = constant_op.constant(1.0) - self.assertEqual(still_on_gpu.device, "/device:GPU:0") - on_cpu = constant_op.constant(1.0) - self.assertEqual(on_cpu.device, "/device:CPU:0") - - def testComposes(self): - g0 = ops.Graph() - g1 = ops.Graph() - g1._building_function = True # pylint: disable=protected-access - g2 = ops.Graph() - g2._building_function = True # pylint: disable=protected-access - g3 = ops.Graph() - g3._building_function = False # pylint: disable=protected-access - - with g0.as_default(): - with g1.as_default(): - with ops.init_scope(): - # This op should be lifted into g0. 
- _ = constant_op.constant(1.0) - self.assertIs(g0, ops.get_default_graph()) - self.assertEqual(len(g2.get_operations()), 0) - self.assertEqual(len(g1.get_operations()), 0) - self.assertEqual(len(g0.get_operations()), 1) - with g2.as_default(): - with ops.init_scope(): - # This op should be lifted into g0. - _ = constant_op.constant(1.0) - self.assertIs(g0, ops.get_default_graph()) - with g3.as_default(): - with ops.init_scope(): - # This op should be lifted into g3, because g3 is not building a - # function. - _ = constant_op.constant(1.0) - self.assertIs(g3, ops.get_default_graph()) - - self.assertEqual(len(g3.get_operations()), 1) - self.assertEqual(len(g2.get_operations()), 0) - self.assertEqual(len(g1.get_operations()), 0) - self.assertEqual(len(g0.get_operations()), 2) - - def testEscapesToEagerContext(self): - g = ops.Graph() - g._building_function = True # pylint: disable=protected-access - with context.eager_mode(): - with context.graph_mode(): - with g.as_default(): - with ops.init_scope(): - # Because g is building a function, init_scope should - # escape out to the eager context. - self.assertTrue(context.executing_eagerly()) - # g should be reinstated as the default graph, and the - # graph context should be re-entered. - self.assertIs(g, ops.get_default_graph()) - self.assertFalse(context.executing_eagerly()) - - def testStaysInEagerWhenOnlyEagerContextActive(self): - with context.eager_mode(): - with ops.init_scope(): - self.assertTrue(context.eager_mode()) - self.assertTrue(context.eager_mode()) - - def testEscapesDefunWhenInEagerMode(self): - - def function_with_variables(): - with ops.init_scope(): - self.v = resource_variable_ops.ResourceVariable(3) - return self.v.assign_add(1) - - with context.eager_mode(): - # Each invocation of function_with_variables recreates a variable. - self.assertEqual(4, int(function_with_variables())) - self.assertEqual(4, int(function_with_variables())) - - compiled = eager_function.defun(function_with_variables) - # The init_scope in function_with_variables lifts the variable out - # of the graph function constructed by defun; hence, - # compiled now appears to be stateful. - self.assertEqual(4, int(compiled())) - self.assertEqual(5, int(compiled())) - - def testEscapesDefunWhenInGraphMode(self): - def function_with_variables(name): - with ops.init_scope(): - _ = variable_scope.get_variable(name, shape=(1,)) - - g = ops.Graph() - with g.as_default(): - with self.cached_session(): - # First ensure that graphs that are not building functions are - # not escaped. - function_with_variables("foo") - with self.assertRaisesRegexp(ValueError, - r"Variable foo already exists.*"): - # This will fail because reuse is not set to True. - function_with_variables("foo") - - compiled = eager_function.defun(function_with_variables) - compiled("bar") - self.assertEqual( - len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) - - # The second call to `compiled` should not create variables: the - # init_scope has lifted the variable creation code out of the defun. 
- compiled("bar") - self.assertEqual( - len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) - - def testEscapesNestedDefun(self): - - def inner_function(): - with ops.init_scope(): - self.v = resource_variable_ops.ResourceVariable(1) - return self.v.assign_add(2) - - def outer_function(inner=None): - with ops.init_scope(): - self.v0 = resource_variable_ops.ResourceVariable(0) - return self.v0.assign_add(1) + inner() - - with context.eager_mode(): - # Each invocation of outer_function recreates variables. - self.assertEqual(4, int(outer_function(inner=inner_function))) - self.assertEqual(4, int(outer_function(inner=inner_function))) - - compiled_inner = eager_function.defun(inner_function) - compiled_outer = eager_function.defun(outer_function) - # The init_scope lifts variables out of the graph functions - # constructed by defun; hence, compiled_outer should now appear to be - # stateful. - self.assertEqual(4, int(compiled_outer(inner=compiled_inner))) - self.assertEqual(7, int(compiled_outer(inner=compiled_inner))) - - @test_util.run_v1_only("b/120545219") - def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self): - with context.graph_mode(): - ops.reset_default_graph() - # This doesn't push anything onto the graph stack, but it does - # set the stack's global graph. - global_graph = ops.get_default_graph() - fn_graph = ops.Graph() - - # pylint: disable=protected-access - fn_graph._building_function = True - self.assertEqual(len(ops._default_graph_stack.stack), 0) - with fn_graph.as_default(): - self.assertEqual(len(ops._default_graph_stack.stack), 1) - with ops.init_scope(): - self.assertGreater(len(ops._default_graph_stack.stack), 1) - dummy = constant_op.constant(1.0) - self.assertEqual(len(ops._default_graph_stack.stack), 1) - # Note that the global graph is _not_ on the graph stack. - self.assertEqual(len(ops._default_graph_stack.stack), 0) - # Ensure that `dummy` was added to the global graph. 
- self.assertEqual(global_graph, dummy.graph) - # pylint: enable=protected-access - - def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self): - with context.graph_mode(): - # pylint: disable=protected-access - self.assertEqual(len(ops._default_graph_stack.stack), 0) - with ops.init_scope(): - self.assertGreater(len(ops._default_graph_stack.stack), 0) - self.assertEqual(len(ops._default_graph_stack.stack), 0) - # pylint: enable=protected-access - - def testPreservesNameScopeInGraphConstruction(self): - with ops.Graph().as_default(): - function_graph = ops.Graph() - with function_graph.as_default(): - with ops.name_scope("inner"), ops.init_scope(): - self.assertEqual(ops.get_name_scope(), "inner") - self.assertEqual(ops.get_name_scope(), "") - - def testEnteringGraphFromEagerIsSticky(self): - with context.eager_mode(): - g = ops.Graph() - with g.as_default(): - with ops.init_scope(): - self.assertFalse(context.executing_eagerly()) - self.assertEqual(g, ops.get_default_graph()) - - def testMixGraphEager(self): - with context.eager_mode(): - c = constant_op.constant(1.0) - with ops.Graph().as_default(): - with self.assertRaisesRegexp( - RuntimeError, "Attempting to capture an EagerTensor"): - math_ops.add(c, c) - c2 = constant_op.constant(2.0) - with self.assertRaisesRegexp( - TypeError, "contains objects other than 'EagerTensor'"): - math_ops.add(c2, c2) - - def testPreservesNameScopeInEagerExecution(self): - with context.eager_mode(): - def foo(): - with ops.name_scope("inner"), ops.init_scope(): - if context.executing_eagerly(): - # A trailing slash is always appended when eager execution is - # enabled. - self.assertEqual(context.context().scope_name, "inner/") - else: - self.assertEqual(ops.get_name_scope(), "inner") - - foo() - self.assertEqual(ops.get_name_scope(), "") - foo_compiled = eager_function.defun(foo) - foo_compiled() - self.assertEqual(ops.get_name_scope(), "") - - def testExecutingEagerlyOutsideFunctions(self): - - @eager_function.defun - def f(): - return ops.executing_eagerly_outside_functions() - - with context.eager_mode(): - self.assertTrue(ops.executing_eagerly_outside_functions()) - self.assertTrue(f()) - g = ops.Graph() - with g.as_default(): - self.assertFalse(ops.executing_eagerly_outside_functions()) - - -class GraphTest(test_util.TensorFlowTestCase): - - def setUp(self): - ops.reset_default_graph() - - def _AssertDefault(self, expected): - self.assertIs(expected, ops.get_default_graph()) - - def testResetDefaultGraphNesting(self): - g0 = ops.Graph() - with self.assertRaises(AssertionError): - with g0.as_default(): - ops.reset_default_graph() - - def testGraphContextManagerCancelsEager(self): - with context.eager_mode(): - with ops.Graph().as_default(): - self.assertFalse(context.executing_eagerly()) - - def testGraphContextManager(self): - g0 = ops.Graph() - with g0.as_default() as g1: - self.assertIs(g0, g1) - - def testDefaultGraph(self): - orig = ops.get_default_graph() - self._AssertDefault(orig) - g0 = ops.Graph() - self._AssertDefault(orig) - context_manager_0 = g0.as_default() - self._AssertDefault(orig) - with context_manager_0 as g0: - self._AssertDefault(g0) - with ops.Graph().as_default() as g1: - self._AssertDefault(g1) - self._AssertDefault(g0) - self._AssertDefault(orig) - - def testPreventFeeding(self): - g = ops.Graph() - a = constant_op.constant(2.0) - self.assertTrue(g.is_feedable(a)) - g.prevent_feeding(a) - self.assertFalse(g.is_feedable(a)) - - @test_util.run_deprecated_v1 - def testPreventFetching(self): - g = ops.Graph() - a = 
constant_op.constant(2.0) - self.assertTrue(g.is_fetchable(a)) - g.prevent_fetching(a.op) - self.assertFalse(g.is_fetchable(a)) - - def testAsGraphElementConversions(self): - - class ConvertibleObj(object): - - def _as_graph_element(self): - return "FloatOutput:0" - - class NonConvertibleObj(object): - - pass - - g = ops.Graph() - a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - self.assertEqual(a, g.as_graph_element(ConvertibleObj())) - with self.assertRaises(TypeError): - g.as_graph_element(NonConvertibleObj()) - - # Regression test against creating custom __del__ functions in classes - # involved in cyclic references, e.g. Graph and Operation. (Python won't gc - # cycles that require calling a __del__ method, because the __del__ method can - # theoretically increase the object's refcount to "save" it from gc, and any - # already-deleted objects in the cycle would have to be restored.) - def testGarbageCollected(self): - # Create a graph we can delete and a weak reference to monitor if it's gc'd - g = ops.Graph() - g_ref = weakref.ref(g) - # Create some ops - with g.as_default(): - a = constant_op.constant(2.0) - b = constant_op.constant(3.0) - c = math_ops.add(a, b) - # Create a session we can delete - with session.Session(graph=g) as sess: - self.evaluate(c) - # Delete all references and trigger gc - del g - del a - del b - del c - del sess - gc.collect() - self.assertIsNone(g_ref()) - - def testRunnableAfterInvalidShape(self): - with ops.Graph().as_default(): - with self.assertRaises(ValueError): - math_ops.add([1, 2], [1, 2, 3]) - a = constant_op.constant(1) - with session.Session() as sess: - self.evaluate(a) - - def testRunnableAfterInvalidShapeWithKernelLabelMap(self): - g = ops.Graph() - with g.as_default(): - with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): - with self.assertRaises(ValueError): - test_ops.kernel_label_required(1) - a = constant_op.constant(1) - with session.Session() as sess: - self.evaluate(a) - - -class AttrScopeTest(test_util.TensorFlowTestCase): - - def _get_test_attrs(self): - x = control_flow_ops.no_op() - try: - a = compat.as_text(x.get_attr("_A")) - except ValueError: - a = None - try: - b = compat.as_text(x.get_attr("_B")) - except ValueError: - b = None - return (a, b) - - @test_util.run_deprecated_v1 - def testNoLabel(self): - with self.cached_session(): - self.assertAllEqual((None, None), self._get_test_attrs()) - - @test_util.run_deprecated_v1 - def testLabelMap(self): - with self.cached_session() as sess: - a1 = self._get_test_attrs() - with sess.graph._attr_scope({ - "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo")) - }): - a2 = self._get_test_attrs() - with sess.graph._attr_scope({ - "_A": None, - "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar")) - }): - a3 = self._get_test_attrs() - with sess.graph._attr_scope({ - "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz")) - }): - a4 = self._get_test_attrs() - a5 = self._get_test_attrs() - a6 = self._get_test_attrs() - a7 = self._get_test_attrs() - - self.assertAllEqual((None, None), a1) - self.assertAllEqual(("foo", None), a2) - self.assertAllEqual((None, "bar"), a3) - self.assertAllEqual(("baz", "bar"), a4) - self.assertAllEqual((None, "bar"), a5) - self.assertAllEqual(("foo", None), a6) - self.assertAllEqual((None, None), a7) - - -ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape) - - -class KernelLabelTest(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testNoLabel(self): - with self.cached_session(): - 
self.assertAllEqual(b"My label is: default", - test_ops.kernel_label().eval()) - - @test_util.run_deprecated_v1 - def testLabelMap(self): - with self.cached_session() as sess: - default_1 = test_ops.kernel_label() - # pylint: disable=protected-access - with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}): - overload_1_1 = test_ops.kernel_label() - with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}): - overload_2 = test_ops.kernel_label() - with sess.graph._kernel_label_map({"KernelLabel": ""}): - default_2 = test_ops.kernel_label() - overload_1_2 = test_ops.kernel_label() - # pylint: enable=protected-access - default_3 = test_ops.kernel_label() - - self.assertAllEqual(b"My label is: default", self.evaluate(default_1)) - self.assertAllEqual(b"My label is: default", self.evaluate(default_2)) - self.assertAllEqual(b"My label is: default", self.evaluate(default_3)) - self.assertAllEqual(b"My label is: overload_1", - self.evaluate(overload_1_1)) - self.assertAllEqual(b"My label is: overload_1", - self.evaluate(overload_1_2)) - self.assertAllEqual(b"My label is: overload_2", self.evaluate(overload_2)) - - -class AsGraphDefTest(test_util.TensorFlowTestCase): - - def testGraphDefVersion(self): - """Test that the graphdef version is plumbed through to kernels.""" - with ops.Graph().as_default() as g: - version = g.graph_def_versions.producer - with self.session(graph=g): - v = test_ops.graph_def_version().eval() - self.assertEqual(version, v) - - def testAddShapes(self): - with ops.Graph().as_default() as g: - t1, t2, t3, t4, t5 = _apply_op(g, "FiveFloatOutputs", [], - [dtypes.float32] * 5) - t1.set_shape(None) - t2.set_shape([]) - t3.set_shape([None]) - t4.set_shape([43, 37]) - t5.set_shape([43, None]) - - b = constant_op.constant(1.0) # pylint: disable=unused-variable - - gd = g.as_graph_def(add_shapes=True) - self.assertProtoEqualsVersion(""" - node { name: "FiveFloatOutputs" op: "FiveFloatOutputs" - attr { - key: "_output_shapes" - value { - list { - shape { unknown_rank: true } - shape { } - shape { dim { size: -1 } } - shape { dim { size: 43 } dim { size: 37 } } - shape { dim { size: 43 } dim { size: -1 } } - } - } - } - } - node { name: "Const" op: "Const" - attr { - key: "_output_shapes" - value { - list { - shape { } - } - } - } - attr { - key: "dtype" - value { type: DT_FLOAT } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { } - float_val: 1.0 } } } } - """, gd) - - -@ops.RegisterStatistics("a", "flops") -def _calc_a_forward_flops(unused_graph, unused_node): - return ops.OpStats("flops", 20) - - -class StatisticsTest(test_util.TensorFlowTestCase): - - def testRegisteredNode(self): - graph = ops.Graph() - node = ops._NodeDef("a", "an_a") - flops = ops.get_stats_for_node_def(graph, node, "flops") - self.assertEqual(20, flops.value) - missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat") - self.assertEqual(None, missing_stat.value) - - def testUnregisteredNode(self): - graph = ops.Graph() - node = ops._NodeDef("b", "a_b") - weight_params = ops.get_stats_for_node_def(graph, node, "weight_params") - self.assertEqual(None, weight_params.value) - - def testAccumulateStatistics(self): - flops_total = ops.OpStats("flops") - self.assertEqual(None, flops_total.value) - second_flops = ops.OpStats("flops", 3) - flops_total += second_flops - self.assertEqual(3, flops_total.value) - - -class DeviceStackTest(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testBasicDeviceAssignmentMetadata(self): - - def 
device_func(unused_op): - return "/cpu:*" - - const_zero = constant_op.constant([0.0], name="zero") - with ops.device("/cpu"): - const_one = constant_op.constant([1.0], name="one") - with ops.device("/cpu:0"): - const_two = constant_op.constant([2.0], name="two") - with ops.device(device_func): - const_three = constant_op.constant(3.0, name="three") - - self.assertEqual(0, len(const_zero.op._device_assignments)) - - one_list = const_one.op._device_assignments - self.assertEqual(1, len(one_list)) - self.assertEqual("/cpu", one_list[0].obj) - self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename)) - - two_list = const_two.op._device_assignments - self.assertEqual(2, len(two_list)) - devices = [t.obj for t in two_list] - self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices)) - - three_list = const_three.op._device_assignments - self.assertEqual(1, len(three_list)) - func_description = three_list[0].obj - expected_regex = r"device_func<.*ops_test.py, [0-9]+" - self.assertRegexpMatches(func_description, expected_regex) - - @test_util.run_deprecated_v1 - def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self): - - with ops.device("/cpu"): - const_one = constant_op.constant([1.0], name="one") - with ops.get_default_graph().device("/cpu"): - const_two = constant_op.constant([2.0], name="two") - - one_metadata = const_one.op._device_assignments[0] - two_metadata = const_two.op._device_assignments[0] - - # Verify both types of device assignment return the right stack info. - self.assertRegexpMatches("ops_test.py", - os.path.basename(one_metadata.filename)) - self.assertEqual(one_metadata.filename, two_metadata.filename) - self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno) - - -class ColocationGroupTest(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testBasic(self): - a = constant_op.constant([2.0], name="a") - with ops.colocate_with(a.op): - b = constant_op.constant(3.0) - c = constant_op.constant(4.0) - self.assertEqual([b"loc:@a"], a.op.colocation_groups()) - self.assertEqual([b"loc:@a"], b.op.colocation_groups()) - with self.assertRaises(ValueError): - c.op.get_attr("_class") - - @test_util.run_deprecated_v1 - def testBasicColocationMetadata(self): - const_two = constant_op.constant([2.0], name="two") - with ops.colocate_with(const_two.op): - const_three = constant_op.constant(3.0, name="three") - locations_dict = const_three.op._colocation_dict - self.assertIn("two", locations_dict) - metadata = locations_dict["two"] - self.assertIsNone(metadata.obj) - # Check that this test's filename is recorded as the file containing the - # colocation statement. - self.assertEqual("ops_test.py", os.path.basename(metadata.filename)) - - @test_util.run_deprecated_v1 - def testColocationDeviceInteraction(self): - with ops.device("/cpu:0"): - with ops.device("/device:GPU:0"): - a = constant_op.constant([2.0], name="a") - with ops.colocate_with(a.op): - # 'b' is created in the scope of /cpu:0, but it is - # colocated with 'a', which is on '/device:GPU:0'. colocate_with - # overrides devices because it is a stronger constraint. 
- b = constant_op.constant(3.0) - self.assertEqual([b"loc:@a"], b.op.colocation_groups()) - self.assertEqual(a.op.device, b.op.device) - - @test_util.run_deprecated_v1 - def testColocationCanonicalization(self): - with ops.device("/device:GPU:0"): - _ = constant_op.constant(2.0) - with ops.device(lambda op: "/device:GPU:0"): - b = constant_op.constant(3.0) - with ops.get_default_graph().colocate_with(b): - with ops.device("/device:GPU:0"): - c = constant_op.constant(4.0) - - # A's device will be /device:GPU:0 - # B's device will be /device:GPU:0 - # C's device will be /device:GPU:0 because it - # inherits B's device name, after canonicalizing the names. - self.assertEqual(b.op.device, c.op.device) - - @test_util.run_deprecated_v1 - def testLocationOverrides(self): - with ops.device("/cpu:0"): - with ops.device("/device:GPU:0"): - a = constant_op.constant([2.0], name="a") - # Note that this colocation is "redundant", since we are - # within the scope of "/device:GPU:0". However, we would like to - # preserve in the GraphDef that these two ops should be - # colocated in a portable way. - with ops.colocate_with(a.op): - b = constant_op.constant(3.0) - c = constant_op.constant(4.0) - d = constant_op.constant(5.0) - - self.assertEqual([b"loc:@a"], b.op.colocation_groups()) - self.assertEqual("/device:GPU:0", a.op.device) - self.assertEqual(a.op.device, b.op.device) - - # Test that device function stack is restored. - self.assertEqual("/device:GPU:0", c.op.device) - self.assertEqual("/device:CPU:0", d.op.device) - - @test_util.run_deprecated_v1 - def testNestedColocateWith(self): - a = constant_op.constant([2.0], name="a") - with ops.colocate_with(a.op): - b = constant_op.constant(3.0) - with ops.colocate_with(b.op): - c = constant_op.constant(4.0) - self.assertEqual([b"loc:@a"], b.op.colocation_groups()) - self.assertEqual([b"loc:@a"], c.op.colocation_groups()) - - @test_util.run_deprecated_v1 - def testMultiColocationGroups(self): - a = constant_op.constant([2.0], name="a") - b = constant_op.constant(3.0, name="b") - with ops.colocate_with(a.op): - with ops.colocate_with(b.op): - c = constant_op.constant(4.0) - self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups())) - - @test_util.run_deprecated_v1 - def testColocationIgnoreStack(self): - a = constant_op.constant([2.0], name="a") - b = constant_op.constant(3.0, name="b") - with ops.colocate_with(a.op): - with ops.colocate_with(b.op, ignore_existing=True): - c = constant_op.constant(4.0) - self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups())) - - @test_util.run_deprecated_v1 - def testColocateWithReset(self): - a = constant_op.constant([2.0], name="a") - with ops.colocate_with(a.op): - b = constant_op.constant(3.0, name="b") - with ops.colocate_with(None, ignore_existing=True): - c = constant_op.constant(4.0, name="c") - self.assertEqual([b"loc:@a"], b.op.colocation_groups()) - self.assertEqual([b"loc:@c"], c.op.colocation_groups()) - - @test_util.run_deprecated_v1 - def testColocateWithInitialNoneThenNested(self): - a = constant_op.constant([2.0], name="a") - with ops.colocate_with(a.op): - with ops.colocate_with(None, ignore_existing=True): - b = constant_op.constant(3.0, name="b") - with ops.colocate_with(b.op): - c = constant_op.constant(4.0, name="c") - self.assertEqual([b"loc:@b"], b.op.colocation_groups()) - self.assertEqual([b"loc:@b"], c.op.colocation_groups()) - - @test_util.run_deprecated_v1 - def testColocateVariables(self): - a = variables.Variable([2.0], name="a") - with ops.colocate_with(a.op): - 
b = variables.Variable([3.0], name="b") - self.assertEqual([b"loc:@a"], b.op.colocation_groups()) - - -class DeprecatedTest(test_util.TensorFlowTestCase): - - def testSuccess(self): - with ops.Graph().as_default() as g: - test_util.set_producer_version(g, 7) - old = test_ops.old() - with self.session(graph=g): - old.run() - - def _error(self): - return ((r"Op Old is not available in GraphDef version %d\. " - r"It has been removed in version 8\. For reasons\.") % - versions.GRAPH_DEF_VERSION) - - def testGraphConstructionFail(self): - with ops.Graph().as_default(): - with self.assertRaisesRegexp(NotImplementedError, self._error()): - test_ops.old() - - -class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase): - - def testSuccess(self): - op = ops.Operation( - ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) - t = op.outputs[0] - self.assertTrue(ops.is_dense_tensor_like(t)) - - v = variables.Variable([17]) - self.assertTrue(ops.is_dense_tensor_like(v)) - - class BadClassNoName(object): - pass - - class BadClassBadName(object): - - def name(self): - pass - - class BadClassNoDtype(object): - - @property - def name(self): - pass - - class BadClassBadDtype(object): - - @property - def name(self): - pass - - def dtype(self): - pass - - def testBadClass(self): - with self.assertRaisesRegexp(TypeError, "`name`"): - ops.register_dense_tensor_like_type( - DenseTensorLikeTypeTest.BadClassNoName) - with self.assertRaisesRegexp(TypeError, "`name`"): - ops.register_dense_tensor_like_type( - DenseTensorLikeTypeTest.BadClassBadName) - with self.assertRaisesRegexp(TypeError, "`dtype`"): - ops.register_dense_tensor_like_type( - DenseTensorLikeTypeTest.BadClassNoDtype) - with self.assertRaisesRegexp(TypeError, "`dtype`"): - ops.register_dense_tensor_like_type( - DenseTensorLikeTypeTest.BadClassBadDtype) - - -class NameScopeTest(test_util.TensorFlowTestCase): - - def testStripAndPrependScope(self): - strs = [ - "hidden1/hidden1/weights", # Same prefix. Should strip. - "hidden1///hidden1/weights", # Extra "/". Should strip. - "^hidden1/hidden1/weights", # Same prefix. Should strip. - "loc:@hidden1/hidden1/weights", # Same prefix. Should strip. - "hhidden1/hidden1/weights", # Different prefix. Should keep. - "hidden1" - ] # Not a prefix. Should keep. 
- expected_striped = [ - "hidden1/weights", "hidden1/weights", "^hidden1/weights", - "loc:@hidden1/weights", "hhidden1/hidden1/weights", "hidden1" - ] - expected_prepended = [ - "hidden2/hidden1/weights", "hidden2/hidden1/weights", - "^hidden2/hidden1/weights", "loc:@hidden2/hidden1/weights", - "hidden2/hhidden1/hidden1/weights", "hidden2/hidden1" - ] - name_scope_to_strip = "hidden1" - name_scope_to_add = "hidden2" - for es, ep, s in zip(expected_striped, expected_prepended, strs): - striped = ops.strip_name_scope(s, name_scope_to_strip) - self.assertEqual(es, striped) - self.assertEqual(ep, ops.prepend_name_scope(striped, name_scope_to_add)) - - def testGetNameScope(self): - with ops.Graph().as_default() as g: - with ops.name_scope("scope1"): - with ops.name_scope("scope2"): - with ops.name_scope("scope3"): - self.assertEqual("scope1/scope2/scope3", g.get_name_scope()) - self.assertEqual("scope1/scope2", g.get_name_scope()) - self.assertEqual("scope1", g.get_name_scope()) - self.assertEqual("", g.get_name_scope()) - - def testTwoGraphs(self): - - def f(): - g1 = ops.Graph() - g2 = ops.Graph() - with g1.as_default(): - with g2.as_default(): - with ops.name_scope("_"): - pass - - self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f) - - -class TracebackTest(test_util.TensorFlowTestCase): - - @test_util.run_deprecated_v1 - def testTracebackWithStartLines(self): - with self.cached_session() as sess: - a = constant_op.constant(2.0) - sess.run( - a, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assertTrue(sess.graph.get_operations()) - - # Tests that traceback_with_start_lines is the same as traceback - # but includes one more element at the end. - for op in sess.graph.get_operations(): - self.assertEquals(len(op.traceback), len(op.traceback_with_start_lines)) - for frame, frame_with_start_line in zip( - op.traceback, op.traceback_with_start_lines): - self.assertEquals(5, len(frame_with_start_line)) - self.assertEquals(frame, frame_with_start_line[:-1]) - - -class EnableEagerExecutionTest(test_util.TensorFlowTestCase): - - @test_util.run_v1_only("b/120545219") - def testBadArgumentsToEnableEagerExecution(self): - with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"): - ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT) - with self.assertRaisesRegexp(ValueError, "device_policy must be one of"): - c = config_pb2.ConfigProto() - ops.enable_eager_execution(c, c) - with self.assertRaisesRegexp(ValueError, "execution_mode must be one of"): - c = config_pb2.ConfigProto() - ops.enable_eager_execution(c, execution_mode=c) - - -if __name__ == "__main__": - googletest.main() diff --git a/test/TensorFlowNET.UnitTest/python/train_saver.py b/test/TensorFlowNET.UnitTest/python/train_saver.py deleted file mode 100644 index 47ffd6a1..00000000 --- a/test/TensorFlowNET.UnitTest/python/train_saver.py +++ /dev/null @@ -1,26 +0,0 @@ - -import tensorflow as tf - -# Create some variables. -v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer) -v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer) - -inc_v1 = v1.assign(v1+1) -dec_v2 = v2.assign(v2-1) - -# Add an op to initialize the variables. -init_op = tf.global_variables_initializer() - -# Add ops to save and restore all the variables. -saver = tf.train.Saver() - -# Later, launch the model, initialize the variables, do some work, and save the -# variables to disk. 
-with tf.Session() as sess: - sess.run(init_op) - # Do some work with the model. - inc_v1.op.run() - dec_v2.op.run() - # Save the variables to disk. - save_path = saver.save(sess, "/tmp/model.ckpt") - print("Model saved in path: %s" % save_path)