| @@ -17,6 +17,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Text", "src\Tens | |||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Recommenders", "src\TensorFlowNET.Recommenders\Tensorflow.Recommenders.csproj", "{F17AAECB-960A-4E18-A270-BAD776F0E55B}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Recommenders", "src\TensorFlowNET.Recommenders\Tensorflow.Recommenders.csproj", "{F17AAECB-960A-4E18-A270-BAD776F0E55B}" | ||||
| EndProject | EndProject | ||||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Native.UnitTest", "test\TensorFlowNET.Native.UnitTest\Tensorflow.Native.UnitTest.csproj", "{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}" | |||||
| EndProject | |||||
| Global | Global | ||||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
| Debug|Any CPU = Debug|Any CPU | Debug|Any CPU = Debug|Any CPU | ||||
| @@ -107,8 +109,8 @@ Global | |||||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.Build.0 = Release|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.Build.0 = Release|Any CPU | ||||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | ||||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.Build.0 = Debug|Any CPU | {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|x64 | |||||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.Build.0 = Debug|x64 | |||||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.ActiveCfg = Debug|Any CPU | {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.ActiveCfg = Debug|Any CPU | ||||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.Build.0 = Debug|Any CPU | {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.Build.0 = Debug|Any CPU | ||||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | ||||
| @@ -155,8 +157,8 @@ Global | |||||
| {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Release|x86.Build.0 = Release|Any CPU | {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Release|x86.Build.0 = Release|Any CPU | ||||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | ||||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.Build.0 = Debug|Any CPU | {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.ActiveCfg = Debug|x64 | |||||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.Build.0 = Debug|x64 | |||||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.ActiveCfg = Debug|Any CPU | {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.ActiveCfg = Debug|Any CPU | ||||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.Build.0 = Debug|Any CPU | {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.Build.0 = Debug|Any CPU | ||||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | ||||
| @@ -179,8 +181,8 @@ Global | |||||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x86.Build.0 = Release|Any CPU | {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x86.Build.0 = Release|Any CPU | ||||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | ||||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.Build.0 = Debug|Any CPU | {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.Build.0 = Debug|Any CPU | |||||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.ActiveCfg = Debug|x64 | |||||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.Build.0 = Debug|x64 | |||||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.ActiveCfg = Debug|Any CPU | {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.ActiveCfg = Debug|Any CPU | ||||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.Build.0 = Debug|Any CPU | {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.Build.0 = Debug|Any CPU | ||||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | ||||
| @@ -201,6 +203,30 @@ Global | |||||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x64.Build.0 = Release|Any CPU | {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x64.Build.0 = Release|Any CPU | ||||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.ActiveCfg = Release|Any CPU | {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.ActiveCfg = Release|Any CPU | ||||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.Build.0 = Release|Any CPU | {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.Build.0 = Release|Any CPU | ||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x64.ActiveCfg = Debug|x64 | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x64.Build.0 = Debug|x64 | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.ActiveCfg = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.Build.0 = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x86.ActiveCfg = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x86.Build.0 = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|Any CPU.ActiveCfg = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|Any CPU.Build.0 = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x64.ActiveCfg = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x64.Build.0 = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x86.ActiveCfg = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x86.Build.0 = Debug|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.ActiveCfg = Release|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.Build.0 = Release|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.ActiveCfg = Release|Any CPU | |||||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.Build.0 = Release|Any CPU | |||||
| EndGlobalSection | EndGlobalSection | ||||
| GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
| HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
| @@ -5,6 +5,7 @@ | |||||
| <TargetFramework>netcoreapp3.1</TargetFramework> | <TargetFramework>netcoreapp3.1</TargetFramework> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <AssemblyName>Tensorflow</AssemblyName> | <AssemblyName>Tensorflow</AssemblyName> | ||||
| <Platforms>AnyCPU;x64</Platforms> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -28,17 +28,10 @@ namespace Tensorflow | |||||
| => ops.reset_default_graph(); | => ops.reset_default_graph(); | ||||
| public Graph get_default_graph() | public Graph get_default_graph() | ||||
| { | |||||
| return ops.get_default_graph(); | |||||
| } | |||||
| => ops.get_default_graph(); | |||||
| /// <summary> | |||||
| /// Equivalent to <see cref="get_default_graph"/> but does not create a new graph if it there is none. | |||||
| /// </summary> | |||||
| public Graph peak_default_graph() | public Graph peak_default_graph() | ||||
| { | |||||
| return ops.default_graph_stack.peak_controller(); | |||||
| } | |||||
| => ops.peak_default_graph(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Creates a new graph. | /// Creates a new graph. | ||||
| @@ -37,19 +37,7 @@ namespace Tensorflow.Contexts | |||||
| if (shouldRunInEager) | if (shouldRunInEager) | ||||
| return eagerAction(); | return eagerAction(); | ||||
| else | else | ||||
| { | |||||
| if (executing_eagerly()) | |||||
| { | |||||
| graph_mode(); | |||||
| var result = graphAction(); | |||||
| restore_mode(); | |||||
| return result; | |||||
| } | |||||
| else | |||||
| { | |||||
| return graphAction(); | |||||
| } | |||||
| } | |||||
| return graphAction(); | |||||
| } | } | ||||
| // [DebuggerStepThrough] | // [DebuggerStepThrough] | ||||
| @@ -80,9 +80,12 @@ namespace Tensorflow.Contexts | |||||
| /// Checks whether the current thread has eager execution enabled. | /// Checks whether the current thread has eager execution enabled. | ||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DebuggerStepThrough] | |||||
| // [DebuggerStepThrough] | |||||
| public bool executing_eagerly() | public bool executing_eagerly() | ||||
| { | { | ||||
| if(context_switches.Count() == 0) | |||||
| tf.enable_eager_execution(); | |||||
| return context_switches.Current().EagerMode; | return context_switches.Current().EagerMode; | ||||
| } | } | ||||
| @@ -103,11 +106,16 @@ namespace Tensorflow.Contexts | |||||
| public void restore_mode() | public void restore_mode() | ||||
| { | { | ||||
| context_switches.Pop(); | context_switches.Pop(); | ||||
| tf.get_default_graph(); | |||||
| } | } | ||||
| public void reset_context() | public void reset_context() | ||||
| { | { | ||||
| c_api.TFE_ContextClearCaches(_handle); | |||||
| ops.reset_uid(); | |||||
| ops.reset_default_graph(); | |||||
| context_switches.Clear(); | |||||
| if (_handle != null) | |||||
| c_api.TFE_ContextClearCaches(_handle); | |||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||
| @@ -40,11 +40,21 @@ namespace Tensorflow.Contexts | |||||
| }); | }); | ||||
| } | } | ||||
| public void Clear() | |||||
| { | |||||
| stack.Clear(); | |||||
| } | |||||
| public void Pop() | public void Pop() | ||||
| { | { | ||||
| stack.Pop(); | stack.Pop(); | ||||
| } | } | ||||
| public int Count() | |||||
| { | |||||
| return stack.Count; | |||||
| } | |||||
| public ContextSwitch Current() | public ContextSwitch Current() | ||||
| { | { | ||||
| return stack.Peek(); | return stack.Peek(); | ||||
| @@ -15,11 +15,13 @@ namespace Tensorflow | |||||
| bool preserve_cardinality = false, | bool preserve_cardinality = false, | ||||
| bool use_legacy_function = false) : base(input_dataset) | bool use_legacy_function = false) : base(input_dataset) | ||||
| { | { | ||||
| using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); | |||||
| using var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}"); | |||||
| func.Enter(); | |||||
| var input = tf.placeholder(input_dataset.element_spec[0].dtype); | var input = tf.placeholder(input_dataset.element_spec[0].dtype); | ||||
| var output = map_func(input); | var output = map_func(input); | ||||
| func.ToGraph(input, output); | func.ToGraph(input, output); | ||||
| func.Exit(); | |||||
| structure = func.OutputStructure; | structure = func.OutputStructure; | ||||
| variant_tensor = ops.map_dataset(input_dataset.variant_tensor, | variant_tensor = ops.map_dataset(input_dataset.variant_tensor, | ||||
| @@ -34,7 +34,6 @@ namespace Tensorflow.Functions | |||||
| public ConcreteFunction(string name) | public ConcreteFunction(string name) | ||||
| { | { | ||||
| func_graph = new FuncGraph(name); | func_graph = new FuncGraph(name); | ||||
| func_graph.as_default(); | |||||
| } | } | ||||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | ||||
| @@ -46,7 +45,7 @@ namespace Tensorflow.Functions | |||||
| public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | ||||
| { | { | ||||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||||
| string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||||
| // IntPtr func_handle; | // IntPtr func_handle; | ||||
| using var graph = new FuncGraph(func_name); | using var graph = new FuncGraph(func_name); | ||||
| @@ -59,11 +58,12 @@ namespace Tensorflow.Functions | |||||
| new[] { input }, | new[] { input }, | ||||
| new[] { output }, | new[] { output }, | ||||
| null); | null); | ||||
| graph.Exit(); | |||||
| } | } | ||||
| public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | ||||
| { | { | ||||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||||
| string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||||
| // IntPtr func_handle; | // IntPtr func_handle; | ||||
| using var graph = new FuncGraph(func_name); | using var graph = new FuncGraph(func_name); | ||||
| @@ -79,12 +79,13 @@ namespace Tensorflow.Functions | |||||
| new[] { input }, | new[] { input }, | ||||
| new[] { output.variant_tensor }, | new[] { output.variant_tensor }, | ||||
| null); | null); | ||||
| graph.Exit(); | |||||
| } | } | ||||
| public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func, | public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func, | ||||
| TF_DataType[] dtypes, TensorShape[] shapes) | TF_DataType[] dtypes, TensorShape[] shapes) | ||||
| { | { | ||||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||||
| string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||||
| // IntPtr func_handle; | // IntPtr func_handle; | ||||
| using var graph = new FuncGraph(func_name); | using var graph = new FuncGraph(func_name); | ||||
| @@ -103,6 +104,7 @@ namespace Tensorflow.Functions | |||||
| new[] { input1, input2, input3 }, | new[] { input1, input2, input3 }, | ||||
| new[] { outputs.Item1, outputs.Item2 }, | new[] { outputs.Item1, outputs.Item2 }, | ||||
| null); | null); | ||||
| graph.Exit(); | |||||
| } | } | ||||
| public void ToGraph(Tensors inputs, Tensors outputs) | public void ToGraph(Tensors inputs, Tensors outputs) | ||||
| @@ -112,10 +114,19 @@ namespace Tensorflow.Functions | |||||
| inputs, | inputs, | ||||
| outputs, | outputs, | ||||
| null); | null); | ||||
| OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray(); | OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray(); | ||||
| } | } | ||||
| public void Enter() | |||||
| { | |||||
| func_graph.as_default(); | |||||
| } | |||||
| public void Exit() | |||||
| { | |||||
| func_graph.Exit(); | |||||
| } | |||||
| public Tensors Invoke(Tensors inputs) | public Tensors Invoke(Tensors inputs) | ||||
| { | { | ||||
| var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly()); | var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly()); | ||||
| @@ -26,7 +26,6 @@ namespace Tensorflow.Functions | |||||
| var output_names = new string[0]; | var output_names = new string[0]; | ||||
| _func_graph = new FuncGraph(graph, name, attrs); | _func_graph = new FuncGraph(graph, name, attrs); | ||||
| _func_graph.as_default(); | |||||
| _func_graph.ToGraph(operations, inputs, outputs, output_names); | _func_graph.ToGraph(operations, inputs, outputs, output_names); | ||||
| } | } | ||||
| @@ -84,7 +84,7 @@ namespace Tensorflow.Functions | |||||
| } | } | ||||
| var gradients_wrt_outputs = new List<Tensor>(); | var gradients_wrt_outputs = new List<Tensor>(); | ||||
| var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}"); | |||||
| var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"); | |||||
| backwards_graph.as_default(); | backwards_graph.as_default(); | ||||
| foreach (var output in trainable_outputs) | foreach (var output in trainable_outputs) | ||||
| gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | ||||
| @@ -101,6 +101,7 @@ namespace Tensorflow.Functions | |||||
| if (!_func_graph.Outputs.Contains(capture)) | if (!_func_graph.Outputs.Contains(capture)) | ||||
| _func_graph.Outputs.Add(capture); | _func_graph.Outputs.Add(capture); | ||||
| } | } | ||||
| backwards_graph.Exit(); | |||||
| var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | ||||
| var backward_function_attr = new Dictionary<string, string>(); | var backward_function_attr = new Dictionary<string, string>(); | ||||
| @@ -8,7 +8,7 @@ namespace Tensorflow.Graphs | |||||
| { | { | ||||
| public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func) | public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func) | ||||
| { | { | ||||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||||
| string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||||
| // IntPtr func_handle; | // IntPtr func_handle; | ||||
| using (var graph = new FuncGraph(func_name)) | using (var graph = new FuncGraph(func_name)) | ||||
| @@ -22,6 +22,7 @@ namespace Tensorflow.Graphs | |||||
| new[] { input }, | new[] { input }, | ||||
| new[] { output }, | new[] { output }, | ||||
| null); | null); | ||||
| graph.Exit(); | |||||
| } | } | ||||
| @@ -39,7 +40,7 @@ namespace Tensorflow.Graphs | |||||
| public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func) | public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func) | ||||
| { | { | ||||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||||
| string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||||
| // IntPtr func_handle; | // IntPtr func_handle; | ||||
| using (var graph = new FuncGraph(func_name)) | using (var graph = new FuncGraph(func_name)) | ||||
| @@ -54,6 +55,7 @@ namespace Tensorflow.Graphs | |||||
| new[] { input1, input2 }, | new[] { input1, input2 }, | ||||
| new[] { output }, | new[] { output }, | ||||
| null); | null); | ||||
| graph.Exit(); | |||||
| } | } | ||||
| return (Tensor a, Tensor b) => | return (Tensor a, Tensor b) => | ||||
| @@ -18,7 +18,7 @@ namespace Tensorflow.Graphs | |||||
| public override void OnEntry(MethodExecutionArgs args) | public override void OnEntry(MethodExecutionArgs args) | ||||
| { | { | ||||
| func_name = $"autograph_{args.Instance.GetHashCode()}.{args.Method.Name}"; | |||||
| func_name = $"{args.Method.Name}_{Guid.NewGuid()}"; | |||||
| if (functions.ContainsKey(func_name)) | if (functions.ContainsKey(func_name)) | ||||
| { | { | ||||
| @@ -34,6 +34,7 @@ namespace Tensorflow.Graphs | |||||
| // make function as an Operation by autograph | // make function as an Operation by autograph | ||||
| // need to restore mode when exits | // need to restore mode when exits | ||||
| function = new ConcreteFunction(func_name); | function = new ConcreteFunction(func_name); | ||||
| function.Enter(); | |||||
| // convert to Tensors | // convert to Tensors | ||||
| if (args.Arguments[0] is Tensors inputs) | if (args.Arguments[0] is Tensors inputs) | ||||
| @@ -68,6 +69,8 @@ namespace Tensorflow.Graphs | |||||
| } | } | ||||
| else | else | ||||
| function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); | function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); | ||||
| function.Exit(); | |||||
| // cache function. | // cache function. | ||||
| function.ReturnType = args.ReturnValue.GetType(); | function.ReturnType = args.ReturnValue.GetType(); | ||||
| @@ -25,63 +25,43 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class DefaultGraphStack | public class DefaultGraphStack | ||||
| { | { | ||||
| private readonly List<StackModel> _stack = new List<StackModel>(); | |||||
| private readonly Stack<Graph> _stack = new Stack<Graph>(); | |||||
| Graph _global_default_graph; | |||||
| public void set_controller(Graph @default) | |||||
| public Graph get_default() | |||||
| { | { | ||||
| if (!_stack.Exists(x => x.Graph == @default)) | |||||
| _stack.Add(new StackModel { Graph = @default, IsDefault = true }); | |||||
| if (_stack.Count > 0) | |||||
| return _stack.Peek(); | |||||
| else if (_global_default_graph != null) | |||||
| return _global_default_graph; | |||||
| else | |||||
| _global_default_graph = new Graph(); | |||||
| foreach (var s in _stack) | |||||
| s.IsDefault = s.Graph == @default; | |||||
| return _global_default_graph; | |||||
| } | } | ||||
| public Graph get_controller() | |||||
| public Graph get_controller(Graph g) | |||||
| { | { | ||||
| if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0) | |||||
| _stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); | |||||
| for (var i = _stack.Count - 1; i >= 0; i--) | |||||
| { | |||||
| var x = _stack[i]; | |||||
| if (x.IsDefault) | |||||
| return x.Graph; | |||||
| } | |||||
| throw new TensorflowException("Unable to find a default graph"); | |||||
| _stack.Push(g); | |||||
| return g; | |||||
| } | } | ||||
| public Graph peak_controller() | public Graph peak_controller() | ||||
| { | { | ||||
| if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0) | |||||
| if (_stack.Count == 0) | |||||
| return null; | return null; | ||||
| for (var i = _stack.Count - 1; i >= 0; i--) | |||||
| { | |||||
| var x = _stack[i]; | |||||
| if (x.IsDefault) | |||||
| return x.Graph; | |||||
| } | |||||
| return null; | |||||
| return _stack.Peek(); | |||||
| } | } | ||||
| public bool remove(Graph g) | |||||
| public void pop() | |||||
| { | { | ||||
| if (_stack.Count == 0) | |||||
| return false; | |||||
| var sm = _stack.Find(model => model.Graph == g); | |||||
| return sm != null && _stack.Remove(sm); | |||||
| _stack.Pop(); | |||||
| } | } | ||||
| public void reset() | public void reset() | ||||
| { | { | ||||
| _stack.Clear(); | _stack.Clear(); | ||||
| } | |||||
| private class StackModel | |||||
| { | |||||
| public Graph Graph { get; set; } | |||||
| public bool IsDefault { get; set; } | |||||
| _global_default_graph = null; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -94,8 +94,6 @@ namespace Tensorflow.Graphs | |||||
| // mark_as_return | // mark_as_return | ||||
| Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray(); | Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray(); | ||||
| tf.Context.restore_mode(); | |||||
| return func_handle; | return func_handle; | ||||
| } | } | ||||
| @@ -247,9 +245,10 @@ namespace Tensorflow.Graphs | |||||
| return this; | return this; | ||||
| } | } | ||||
| protected override void DisposeManagedResources() | |||||
| public override void Exit() | |||||
| { | { | ||||
| base.DisposeManagedResources(); | |||||
| tf.Context.restore_mode(); | |||||
| ops.pop_graph(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -146,6 +146,7 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns a context manager that makes this `Graph` the default graph. | /// Returns a context manager that makes this `Graph` the default graph. | ||||
| /// Must call Exit() to pop graph | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public virtual Graph as_default() | public virtual Graph as_default() | ||||
| @@ -487,7 +488,7 @@ namespace Tensorflow | |||||
| protected override void DisposeManagedResources() | protected override void DisposeManagedResources() | ||||
| { | { | ||||
| ops.default_graph_stack.remove(this); | |||||
| } | } | ||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
| @@ -529,6 +530,11 @@ namespace Tensorflow | |||||
| return new TensorShape(dims.Select(x => (int)x).ToArray()); | return new TensorShape(dims.Select(x => (int)x).ToArray()); | ||||
| } | } | ||||
| public virtual void Exit() | |||||
| { | |||||
| ops.pop_graph(); | |||||
| } | |||||
| string debugString = string.Empty; | string debugString = string.Empty; | ||||
| public override string ToString() | public override string ToString() | ||||
| { | { | ||||
| @@ -95,7 +95,7 @@ namespace Tensorflow.Graphs | |||||
| _copy_non_source(op, graph, op_map, base_graph); | _copy_non_source(op, graph, op_map, base_graph); | ||||
| } | } | ||||
| tf.Context.restore_mode(); | |||||
| graph.Exit(); | |||||
| return op_map; | return op_map; | ||||
| } | } | ||||
| @@ -30,7 +30,7 @@ namespace Tensorflow | |||||
| public Operation _apply_op_helper(string op_type_name, string name = null, Dictionary<string, object> keywords = null) | public Operation _apply_op_helper(string op_type_name, string name = null, Dictionary<string, object> keywords = null) | ||||
| { | { | ||||
| var g = ops.get_default_graph(); | |||||
| var g = ops._get_graph_from_inputs(keywords == null ? new object[0] : keywords.Values.ToArray()); | |||||
| var op_def = g.GetOpDef(op_type_name); | var op_def = g.GetOpDef(op_type_name); | ||||
| // Default name if not specified. | // Default name if not specified. | ||||
| @@ -59,7 +59,8 @@ namespace Tensorflow | |||||
| var input_types = new List<TF_DataType>(); | var input_types = new List<TF_DataType>(); | ||||
| object values = null; | object values = null; | ||||
| return tf_with(ops.name_scope(name), scope => | |||||
| g.as_default(); | |||||
| var ret_op = tf_with(ops.name_scope(name), scope => | |||||
| { | { | ||||
| var inferred_from = new Dictionary<string, object>(); | var inferred_from = new Dictionary<string, object>(); | ||||
| var base_types = new List<TF_DataType>(); | var base_types = new List<TF_DataType>(); | ||||
| @@ -249,6 +250,8 @@ namespace Tensorflow | |||||
| return op; | return op; | ||||
| }); | }); | ||||
| g.Exit(); | |||||
| return ret_op; | |||||
| } | } | ||||
| private void _MaybeColocateWith(ITensorOrOperation[] inputs) | private void _MaybeColocateWith(ITensorOrOperation[] inputs) | ||||
| @@ -78,6 +78,21 @@ namespace Tensorflow | |||||
| return get_default_graph().get_collection_ref<T>(key); | return get_default_graph().get_collection_ref<T>(key); | ||||
| } | } | ||||
| public static Graph _get_graph_from_inputs(params object[] op_input_list) | |||||
| { | |||||
| var current_default_graph = get_default_graph(); | |||||
| if (current_default_graph.building_function) | |||||
| return current_default_graph; | |||||
| Graph graph = null; | |||||
| foreach (var op_input in op_input_list) | |||||
| { | |||||
| if (op_input is Tensor op_input_tensor) | |||||
| graph = graph ?? op_input_tensor.graph; | |||||
| } | |||||
| return graph ?? current_default_graph; | |||||
| } | |||||
| public static Graph _get_graph_from_inputs(Tensors op_input_list) | public static Graph _get_graph_from_inputs(Tensors op_input_list) | ||||
| => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); | => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); | ||||
| @@ -337,6 +352,11 @@ namespace Tensorflow | |||||
| return Interlocked.Increment(ref uid_number); | return Interlocked.Increment(ref uid_number); | ||||
| } | } | ||||
| public static void reset_uid() | |||||
| { | |||||
| uid_number = -1; | |||||
| } | |||||
| public static void colocate_with(bool ignore_existing = false) | public static void colocate_with(bool ignore_existing = false) | ||||
| { | { | ||||
| _colocate_with_for_gradient(null, null, ignore_existing); | _colocate_with_for_gradient(null, null, ignore_existing); | ||||
| @@ -118,16 +118,10 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Graph get_default_graph() | public static Graph get_default_graph() | ||||
| { | |||||
| //return _default_graph_stack.get_default() | |||||
| return default_graph_stack.get_controller(); | |||||
| } | |||||
| => default_graph_stack.get_default(); | |||||
| public static Graph set_default_graph(Graph graph) | |||||
| { | |||||
| default_graph_stack.set_controller(graph); | |||||
| return default_graph_stack.get_controller(); | |||||
| } | |||||
| public static Graph set_default_graph(Graph g) | |||||
| => default_graph_stack.get_controller(g); | |||||
| /// <summary> | /// <summary> | ||||
| /// Clears the default graph stack and resets the global default graph. | /// Clears the default graph stack and resets the global default graph. | ||||
| @@ -147,5 +141,11 @@ namespace Tensorflow | |||||
| // "exit the nesting and create a new graph."); | // "exit the nesting and create a new graph."); | ||||
| default_graph_stack.reset(); | default_graph_stack.reset(); | ||||
| } | } | ||||
| public static Graph peak_default_graph() | |||||
| => default_graph_stack.peak_controller(); | |||||
| public static void pop_graph() | |||||
| => default_graph_stack.pop(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -115,10 +115,10 @@ namespace Tensorflow.Keras | |||||
| public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>(); | public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>(); | ||||
| public void clear_session() | public void clear_session() | ||||
| { | { | ||||
| ops.reset_default_graph(); | |||||
| tf.Context.reset_context(); | |||||
| reset_uids(); | reset_uids(); | ||||
| ops.set_default_session(tf.Session(ops.get_default_graph())); | ops.set_default_session(tf.Session(ops.get_default_graph())); | ||||
| var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase"); | |||||
| // var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase"); | |||||
| _GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>(); | _GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>(); | ||||
| _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0; | _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0; | ||||
| } | } | ||||
| @@ -185,7 +185,7 @@ namespace Tensorflow.Keras | |||||
| return tensor_util.constant_value(outputs); | return tensor_util.constant_value(outputs); | ||||
| var source_graph = outputs.graph; | var source_graph = outputs.graph; | ||||
| using var exec_graph = _scratch_graph(); | |||||
| var exec_graph = _scratch_graph(); | |||||
| var global_graph = get_graph(); | var global_graph = get_graph(); | ||||
| if (source_graph == global_graph && exec_graph != global_graph) | if (source_graph == global_graph && exec_graph != global_graph) | ||||
| { | { | ||||
| @@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine | |||||
| _set_mask_metadata(inputs, outputs, null); | _set_mask_metadata(inputs, outputs, null); | ||||
| }); | }); | ||||
| tf.Context.restore_mode(); | |||||
| graph.Exit(); | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| @@ -81,8 +81,9 @@ namespace Tensorflow.Keras.Layers | |||||
| sparse: args.Sparse, | sparse: args.Sparse, | ||||
| ragged: args.Ragged); | ragged: args.Ragged); | ||||
| graph.Exit(); | |||||
| isPlaceholder = true; | isPlaceholder = true; | ||||
| tf.Context.restore_mode(); | |||||
| } | } | ||||
| // Create an input node to add to self.outbound_node | // Create an input node to add to self.outbound_node | ||||
| @@ -5,6 +5,7 @@ | |||||
| <Version>0.0.1</Version> | <Version>0.0.1</Version> | ||||
| <Description>TensorFlow Recommenders is a library for building recommender system models using TensorFlow.</Description> | <Description>TensorFlow Recommenders is a library for building recommender system models using TensorFlow.</Description> | ||||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
| <Platforms>AnyCPU;x64</Platforms> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -7,12 +7,17 @@ | |||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
| <Version>0.0.1</Version> | <Version>0.0.1</Version> | ||||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
| <Platforms>AnyCPU;x64</Platforms> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
| <DefineConstants>DEBUG;TRACE</DefineConstants> | <DefineConstants>DEBUG;TRACE</DefineConstants> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | |||||
| <DefineConstants>DEBUG;TRACE</DefineConstants> | |||||
| </PropertyGroup> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <None Include="..\..\LICENSE"> | <None Include="..\..\LICENSE"> | ||||
| <Pack>True</Pack> | <Pack>True</Pack> | ||||
| @@ -1,8 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | |||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// tensorflow\c\c_api_test.cc | /// tensorflow\c\c_api_test.cc | ||||
| @@ -1,9 +1,8 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Tensorflow; | |||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// tensorflow\c\c_api_test.cc | /// tensorflow\c\c_api_test.cc | ||||
| @@ -2,10 +2,9 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow; | |||||
| using static TensorFlowNET.UnitTest.c_test_util; | |||||
| using static Tensorflow.Native.UnitTest.c_test_util; | |||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// tensorflow\c\c_api_function_test.cc | /// tensorflow\c\c_api_function_test.cc | ||||
| @@ -1,11 +1,9 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using Buffer = Tensorflow.Buffer; | |||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// tensorflow\c\c_api_test.cc | /// tensorflow\c\c_api_test.cc | ||||
| @@ -1,13 +1,11 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Device; | using Tensorflow.Device; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.UnitTest; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| namespace Tensorflow.Native.UnitTest | |||||
| { | { | ||||
| public class CApiTest : GraphModeTestBase | |||||
| public class CApiTest | |||||
| { | { | ||||
| protected static readonly TF_Code TF_OK = TF_Code.TF_OK; | protected static readonly TF_Code TF_OK = TF_Code.TF_OK; | ||||
| protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | ||||
| @@ -1,10 +1,9 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| namespace TensorFlowNET.UnitTest | |||||
| namespace Tensorflow.Native.UnitTest | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// tensorflow\c\c_test_util.cc | /// tensorflow\c\c_test_util.cc | ||||
| @@ -1,9 +1,8 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Device; | using Tensorflow.Device; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest.Eager | |||||
| { | { | ||||
| public partial class CApiEagerTest | public partial class CApiEagerTest | ||||
| { | { | ||||
| @@ -1,10 +1,9 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest.Eager | |||||
| { | { | ||||
| public partial class CApiEagerTest | public partial class CApiEagerTest | ||||
| { | { | ||||
| @@ -1,8 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest.Eager | |||||
| { | { | ||||
| public partial class CApiEagerTest | public partial class CApiEagerTest | ||||
| { | { | ||||
| @@ -1,8 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest.Eager | |||||
| { | { | ||||
| public partial class CApiEagerTest | public partial class CApiEagerTest | ||||
| { | { | ||||
| @@ -1,8 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest.Eager | |||||
| { | { | ||||
| public partial class CApiEagerTest | public partial class CApiEagerTest | ||||
| { | { | ||||
| @@ -1,8 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest.Eager | |||||
| { | { | ||||
| public partial class CApiEagerTest | public partial class CApiEagerTest | ||||
| { | { | ||||
| @@ -1,9 +1,8 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest.Eager | |||||
| { | { | ||||
| public partial class CApiEagerTest | public partial class CApiEagerTest | ||||
| { | { | ||||
| @@ -1,10 +1,9 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest.Eager | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// tensorflow\c\eager\c_api_test.cc | /// tensorflow\c\eager\c_api_test.cc | ||||
| @@ -1,13 +1,12 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.Gradient | |||||
| namespace Tensorflow.Native.UnitTest.Eager | |||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class GradientEagerTest : PythonTest | |||||
| public class GradientEagerTest | |||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void ConstantSquare() | public void ConstantSquare() | ||||
| @@ -1,8 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest | |||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class GraphBuildTest : CApiTest | public class GraphBuildTest : CApiTest | ||||
| @@ -1,10 +1,8 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Buffer = Tensorflow.Buffer; | |||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace Tensorflow.Native.UnitTest | |||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class GraphTest : CApiTest | public class GraphTest : CApiTest | ||||
| @@ -0,0 +1,74 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Native.UnitTest.Sessions | |||||
| { | |||||
| [TestClass, Ignore] | |||||
| public class SessionTest : CApiTest | |||||
| { | |||||
| /// <summary> | |||||
| /// tensorflow\c\c_api_test.cc | |||||
| /// `TEST(CAPI, Session)` | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void Session() | |||||
| { | |||||
| using var s = new Status(); | |||||
| using var graph = new Graph(); | |||||
| // Make a placeholder operation. | |||||
| var feed = c_test_util.Placeholder(graph, s); | |||||
| // Make a constant operation with the scalar "2". | |||||
| var two = c_test_util.ScalarConst(2, graph, s); | |||||
| // Add operation. | |||||
| var add = c_test_util.Add(feed, two, graph, s); | |||||
| var csession = new CSession(graph, s); | |||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
| // Run the graph. | |||||
| var inputs = new Dictionary<Operation, Tensor>(); | |||||
| inputs.Add(feed, new Tensor(3)); | |||||
| csession.SetInputs(inputs); | |||||
| var outputs = new TF_Output[] { new TF_Output(add, 0) }; | |||||
| csession.SetOutputs(outputs); | |||||
| csession.Run(s); | |||||
| Tensor outTensor = csession.output_tensor(0); | |||||
| EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | |||||
| EXPECT_EQ(0, outTensor.NDims); | |||||
| ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | |||||
| var output_contents = outTensor.ToArray<int>(); | |||||
| EXPECT_EQ(3 + 2, output_contents[0]); | |||||
| // Add another operation to the graph. | |||||
| var neg = c_test_util.Neg(add, graph, s); | |||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
| // Run up to the new operation. | |||||
| inputs = new Dictionary<Operation, Tensor>(); | |||||
| inputs.Add(feed, new Tensor(7)); | |||||
| csession.SetInputs(inputs); | |||||
| outputs = new TF_Output[] { new TF_Output(neg, 0) }; | |||||
| csession.SetOutputs(outputs); | |||||
| csession.Run(s); | |||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
| outTensor = csession.output_tensor(0); | |||||
| ASSERT_TRUE(outTensor != IntPtr.Zero); | |||||
| EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | |||||
| EXPECT_EQ(0, outTensor.NDims); // scalar | |||||
| ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | |||||
| output_contents = outTensor.ToArray<int>(); | |||||
| EXPECT_EQ(-(7 + 2), output_contents[0]); | |||||
| // Clean up | |||||
| csession.CloseAndDelete(s); | |||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,36 @@ | |||||
| <Project Sdk="Microsoft.NET.Sdk"> | |||||
| <PropertyGroup> | |||||
| <TargetFramework>netcoreapp3.1</TargetFramework> | |||||
| <IsPackable>false</IsPackable> | |||||
| <Platforms>AnyCPU;x64</Platforms> | |||||
| </PropertyGroup> | |||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||||
| <DefineConstants>DEBUG;TRACE</DefineConstants> | |||||
| <PlatformTarget>x64</PlatformTarget> | |||||
| </PropertyGroup> | |||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | |||||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||||
| <DefineConstants>DEBUG;TRACE</DefineConstants> | |||||
| <PlatformTarget>x64</PlatformTarget> | |||||
| </PropertyGroup> | |||||
| <ItemGroup> | |||||
| <PackageReference Include="FluentAssertions" Version="5.10.3" /> | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.8.3" /> | |||||
| <PackageReference Include="MSTest.TestAdapter" Version="2.1.2" /> | |||||
| <PackageReference Include="MSTest.TestFramework" Version="2.1.2" /> | |||||
| <PackageReference Include="coverlet.collector" Version="1.3.0" /> | |||||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" /> | |||||
| </ItemGroup> | |||||
| <ItemGroup> | |||||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | |||||
| </ItemGroup> | |||||
| </Project> | |||||
| @@ -0,0 +1,204 @@ | |||||
| using FluentAssertions; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Linq; | |||||
| using System.Runtime.InteropServices; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Native.UnitTest.Tensors | |||||
| { | |||||
| [TestClass] | |||||
| public class TensorTest : CApiTest | |||||
| { | |||||
| [TestMethod] | |||||
| public unsafe void TensorFromFixed() | |||||
| { | |||||
| var array = new float[1000]; | |||||
| var span = new Span<float>(array, 100, 500); | |||||
| fixed (float* ptr = &MemoryMarshal.GetReference(span)) | |||||
| { | |||||
| using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, 4 * span.Length)) | |||||
| { | |||||
| Assert.IsFalse(t.IsDisposed); | |||||
| Assert.AreEqual(2000, (int)t.bytesize); | |||||
| } | |||||
| } | |||||
| fixed (float* ptr = &array[0]) | |||||
| { | |||||
| using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length)) | |||||
| { | |||||
| Assert.IsFalse(t.IsDisposed); | |||||
| Assert.AreEqual(4000, (int)t.bytesize); | |||||
| } | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void TensorFromArray() | |||||
| { | |||||
| var array = new float[1000]; | |||||
| using (var t = new Tensor(array, new long[] { array.Length }, tf.float32)) | |||||
| { | |||||
| Assert.IsFalse(t.IsDisposed); | |||||
| Assert.AreEqual(1000 * sizeof(float), (int)t.bytesize); | |||||
| } | |||||
| using (var t = new Tensor(new float[] { 1 }, new long[] { 1 }, tf.float32)) | |||||
| { | |||||
| Assert.IsFalse(t.IsDisposed); | |||||
| Assert.AreEqual(1 * sizeof(float), (int)t.bytesize); | |||||
| } | |||||
| using (var t = new Tensor(new float[] { 1 }, null, tf.float32)) | |||||
| { | |||||
| Assert.IsFalse(t.IsDisposed); | |||||
| Assert.AreEqual(1 * sizeof(float), (int)t.bytesize); | |||||
| t.shape.Should().BeEmpty(); | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void AllocateTensor() | |||||
| { | |||||
| ulong num_bytes = 6 * sizeof(float); | |||||
| long[] dims = { 2, 3 }; | |||||
| Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | |||||
| EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); | |||||
| EXPECT_EQ(2, t.NDims); | |||||
| EXPECT_EQ((int)dims[0], t.shape[0]); | |||||
| EXPECT_EQ(num_bytes, t.bytesize); | |||||
| t.Dispose(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Port from c_api_test.cc | |||||
| /// `TEST(CAPI, MaybeMove)` | |||||
| /// </summary> | |||||
| [TestMethod, Ignore] | |||||
| public void MaybeMove() | |||||
| { | |||||
| NDArray nd = np.array(2, 3); | |||||
| Tensor t = new Tensor(nd); | |||||
| Tensor o = t.MaybeMove(); | |||||
| ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own. | |||||
| t.Dispose(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Port from c_api_test.cc | |||||
| /// `TEST(CAPI, Tensor)` | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void Tensor() | |||||
| { | |||||
| var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | |||||
| var tensor = new Tensor(nd); | |||||
| var array = tensor.ToArray<float>(); | |||||
| EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | |||||
| EXPECT_EQ(tensor.rank, nd.ndim); | |||||
| EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); | |||||
| EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); | |||||
| EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] { 1, 2, 3, 4, 5, 6 })); | |||||
| } | |||||
| /// <summary> | |||||
| /// Port from tensorflow\c\c_api_test.cc | |||||
| /// `TEST(CAPI, SetShape)` | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void SetShape() | |||||
| { | |||||
| var s = new Status(); | |||||
| var graph = new Graph().as_default(); | |||||
| var feed = c_test_util.Placeholder(graph, s); | |||||
| var feed_out_0 = new TF_Output(feed, 0); | |||||
| // Fetch the shape, it should be completely unknown. | |||||
| int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| EXPECT_EQ(-1, num_dims); | |||||
| // Set the shape to be unknown, expect no change. | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||||
| EXPECT_EQ(-1, num_dims); | |||||
| // Set the shape to be 2 x Unknown | |||||
| long[] dims = { 2, -1 }; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||||
| EXPECT_EQ(2, num_dims); | |||||
| // Get the dimension vector appropriately. | |||||
| var returned_dims = new long[dims.Length]; | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||||
| // Set to a new valid shape: [2, 3] | |||||
| dims[1] = 3; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| // Fetch and see that the new value is returned. | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||||
| // Try to set 'unknown' with unknown rank on the shape and see that | |||||
| // it doesn't change. | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| EXPECT_EQ(2, num_dims); | |||||
| EXPECT_EQ(2, (int)returned_dims[0]); | |||||
| EXPECT_EQ(3, (int)returned_dims[1]); | |||||
| // Try to set 'unknown' with same rank on the shape and see that | |||||
| // it doesn't change. | |||||
| dims[0] = -1; | |||||
| dims[1] = -1; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| EXPECT_EQ(2, num_dims); | |||||
| EXPECT_EQ(2, (int)returned_dims[0]); | |||||
| EXPECT_EQ(3, (int)returned_dims[1]); | |||||
| // Try to fetch a shape with the wrong num_dims | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||||
| // Try to set an invalid shape (cannot change 2x3 to a 2x5). | |||||
| dims[1] = 5; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||||
| // Test for a scalar. | |||||
| var three = c_test_util.ScalarConst(3, graph, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| var three_out_0 = new TF_Output(three, 0); | |||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| EXPECT_EQ(0, num_dims); | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||||
| graph.Exit(); | |||||
| s.Dispose(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,10 +1,8 @@ | |||||
| using System; | using System; | ||||
| using System.Diagnostics.CodeAnalysis; | using System.Diagnostics.CodeAnalysis; | ||||
| using Tensorflow; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using Buffer = Tensorflow.Buffer; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| namespace Tensorflow.Native.UnitTest | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Port from `tensorflow\c\c_test_util.cc` | /// Port from `tensorflow\c\c_test_util.cc` | ||||
| @@ -8,78 +8,11 @@ using Tensorflow; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| [TestClass] | |||||
| public class SessionTest : CApiTest | |||||
| [TestClass, Ignore] | |||||
| public class SessionTest | |||||
| { | { | ||||
| /// <summary> | |||||
| /// tensorflow\c\c_api_test.cc | |||||
| /// `TEST(CAPI, Session)` | |||||
| /// </summary> | |||||
| [TestMethod, Ignore] | |||||
| public void Session() | |||||
| { | |||||
| lock (Locks.ProcessWide) | |||||
| { | |||||
| var s = new Status(); | |||||
| var graph = new Graph().as_default(); | |||||
| // Make a placeholder operation. | |||||
| var feed = c_test_util.Placeholder(graph, s); | |||||
| // Make a constant operation with the scalar "2". | |||||
| var two = c_test_util.ScalarConst(2, graph, s); | |||||
| // Add operation. | |||||
| var add = c_test_util.Add(feed, two, graph, s); | |||||
| var csession = new CSession(graph, s); | |||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
| // Run the graph. | |||||
| var inputs = new Dictionary<Operation, Tensor>(); | |||||
| inputs.Add(feed, new Tensor(3)); | |||||
| csession.SetInputs(inputs); | |||||
| var outputs = new TF_Output[] { new TF_Output(add, 0) }; | |||||
| csession.SetOutputs(outputs); | |||||
| csession.Run(s); | |||||
| Tensor outTensor = csession.output_tensor(0); | |||||
| EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | |||||
| EXPECT_EQ(0, outTensor.NDims); | |||||
| ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | |||||
| var output_contents = outTensor.ToArray<int>(); | |||||
| EXPECT_EQ(3 + 2, output_contents[0]); | |||||
| // Add another operation to the graph. | |||||
| var neg = c_test_util.Neg(add, graph, s); | |||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
| // Run up to the new operation. | |||||
| inputs = new Dictionary<Operation, Tensor>(); | |||||
| inputs.Add(feed, new Tensor(7)); | |||||
| csession.SetInputs(inputs); | |||||
| outputs = new TF_Output[] { new TF_Output(neg, 0) }; | |||||
| csession.SetOutputs(outputs); | |||||
| csession.Run(s); | |||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
| outTensor = csession.output_tensor(0); | |||||
| ASSERT_TRUE(outTensor != IntPtr.Zero); | |||||
| EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | |||||
| EXPECT_EQ(0, outTensor.NDims); // scalar | |||||
| ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | |||||
| output_contents = outTensor.ToArray<int>(); | |||||
| EXPECT_EQ(-(7 + 2), output_contents[0]); | |||||
| // Clean up | |||||
| csession.CloseAndDelete(s); | |||||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||||
| } | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void EvalTensor() | public void EvalTensor() | ||||
| { | { | ||||
| @@ -7,201 +7,11 @@ using System.Runtime.InteropServices; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| [TestClass] | |||||
| public class TensorTest : CApiTest | |||||
| [TestClass, Ignore] | |||||
| public class TensorTest | |||||
| { | { | ||||
| [TestMethod] | |||||
| public unsafe void TensorFromFixed() | |||||
| { | |||||
| var array = new float[1000]; | |||||
| var span = new Span<float>(array, 100, 500); | |||||
| fixed (float* ptr = &MemoryMarshal.GetReference(span)) | |||||
| { | |||||
| using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, 4 * span.Length)) | |||||
| { | |||||
| Assert.IsFalse(t.IsDisposed); | |||||
| Assert.AreEqual(2000, (int)t.bytesize); | |||||
| } | |||||
| } | |||||
| fixed (float* ptr = &array[0]) | |||||
| { | |||||
| using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length)) | |||||
| { | |||||
| Assert.IsFalse(t.IsDisposed); | |||||
| Assert.AreEqual(4000, (int)t.bytesize); | |||||
| } | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public unsafe void TensorFromArray() | |||||
| { | |||||
| var array = new float[1000]; | |||||
| using (var t = new Tensor(array, new long[] { array.Length }, tf.float32)) | |||||
| { | |||||
| Assert.IsFalse(t.IsDisposed); | |||||
| Assert.AreEqual(1000 * sizeof(float), (int)t.bytesize); | |||||
| } | |||||
| using (var t = new Tensor(new float[] { 1 }, new long[] { 1 }, tf.float32)) | |||||
| { | |||||
| Assert.IsFalse(t.IsDisposed); | |||||
| Assert.AreEqual(1 * sizeof(float), (int)t.bytesize); | |||||
| } | |||||
| using (var t = new Tensor(new float[] { 1 }, null, tf.float32)) | |||||
| { | |||||
| Assert.IsFalse(t.IsDisposed); | |||||
| Assert.AreEqual(1 * sizeof(float), (int)t.bytesize); | |||||
| t.shape.Should().BeEmpty(); | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void AllocateTensor() | |||||
| { | |||||
| ulong num_bytes = 6 * sizeof(float); | |||||
| long[] dims = { 2, 3 }; | |||||
| Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | |||||
| EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); | |||||
| EXPECT_EQ(2, t.NDims); | |||||
| EXPECT_EQ((int)dims[0], t.shape[0]); | |||||
| EXPECT_EQ(num_bytes, t.bytesize); | |||||
| t.Dispose(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Port from c_api_test.cc | |||||
| /// `TEST(CAPI, MaybeMove)` | |||||
| /// </summary> | |||||
| [TestMethod, Ignore] | |||||
| public void MaybeMove() | |||||
| { | |||||
| NDArray nd = np.array(2, 3); | |||||
| Tensor t = new Tensor(nd); | |||||
| Tensor o = t.MaybeMove(); | |||||
| ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own. | |||||
| t.Dispose(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Port from c_api_test.cc | |||||
| /// `TEST(CAPI, Tensor)` | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void Tensor() | |||||
| { | |||||
| var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | |||||
| var tensor = new Tensor(nd); | |||||
| var array = tensor.ToArray<float>(); | |||||
| EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | |||||
| EXPECT_EQ(tensor.rank, nd.ndim); | |||||
| EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); | |||||
| EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); | |||||
| EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] { 1, 2, 3, 4, 5, 6 })); | |||||
| } | |||||
| /// <summary> | |||||
| /// Port from tensorflow\c\c_api_test.cc | |||||
| /// `TEST(CAPI, SetShape)` | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void SetShape() | |||||
| { | |||||
| var s = new Status(); | |||||
| var graph = new Graph().as_default(); | |||||
| var feed = c_test_util.Placeholder(graph, s); | |||||
| var feed_out_0 = new TF_Output(feed, 0); | |||||
| // Fetch the shape, it should be completely unknown. | |||||
| int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| EXPECT_EQ(-1, num_dims); | |||||
| // Set the shape to be unknown, expect no change. | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||||
| EXPECT_EQ(-1, num_dims); | |||||
| // Set the shape to be 2 x Unknown | |||||
| long[] dims = { 2, -1 }; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||||
| EXPECT_EQ(2, num_dims); | |||||
| // Get the dimension vector appropriately. | |||||
| var returned_dims = new long[dims.Length]; | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||||
| // Set to a new valid shape: [2, 3] | |||||
| dims[1] = 3; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| // Fetch and see that the new value is returned. | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||||
| // Try to set 'unknown' with unknown rank on the shape and see that | |||||
| // it doesn't change. | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| EXPECT_EQ(2, num_dims); | |||||
| EXPECT_EQ(2, (int)returned_dims[0]); | |||||
| EXPECT_EQ(3, (int)returned_dims[1]); | |||||
| // Try to set 'unknown' with same rank on the shape and see that | |||||
| // it doesn't change. | |||||
| dims[0] = -1; | |||||
| dims[1] = -1; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| EXPECT_EQ(2, num_dims); | |||||
| EXPECT_EQ(2, (int)returned_dims[0]); | |||||
| EXPECT_EQ(3, (int)returned_dims[1]); | |||||
| // Try to fetch a shape with the wrong num_dims | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||||
| // Try to set an invalid shape (cannot change 2x3 to a 2x5). | |||||
| dims[1] = 5; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||||
| // Test for a scalar. | |||||
| var three = c_test_util.ScalarConst(3, graph, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| var three_out_0 = new TF_Output(three, 0); | |||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| EXPECT_EQ(0, num_dims); | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s.Handle); | |||||
| //Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| // graph.Dispose(); | |||||
| s.Dispose(); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void sparse_to_dense() | public void sparse_to_dense() | ||||
| { | { | ||||
| @@ -271,32 +81,5 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>())); | Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>())); | ||||
| } | } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Creates a tensor from an image of 256x256x3 and resizes it to 100x100x3 | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public unsafe void tensor_resize() | |||||
| { | |||||
| tf.enable_eager_execution(); | |||||
| var imageArray = new float[256 * 256 * 3]; | |||||
| using var newSize = tf.convert_to_tensor(new int[] { 100, 100 }); | |||||
| using (var t = tf.constant(imageArray, tf.float32, (1, 256, 256, 3))) | |||||
| { | |||||
| Assert.IsFalse(t.IsDisposed); | |||||
| Assert.AreEqual(256 * 256 * 3 * sizeof(float), (int)t.bytesize); | |||||
| using var resized = tf.image.resize_bilinear(t, newSize); | |||||
| EXPECT_EQ(resized.shape[0], 1); | |||||
| EXPECT_EQ(resized.shape[1], 100); | |||||
| EXPECT_EQ(resized.shape[2], 100); | |||||
| EXPECT_EQ(resized.shape[3], 3); | |||||
| } | |||||
| tf.compat.v1.disable_eager_execution(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| @@ -10,11 +11,40 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| if (!tf.executing_eagerly()) | if (!tf.executing_eagerly()) | ||||
| tf.enable_eager_execution(); | tf.enable_eager_execution(); | ||||
| tf.Context.ensure_initialized(); | |||||
| } | } | ||||
| [TestCleanup] | [TestCleanup] | ||||
| public void TestClean() | public void TestClean() | ||||
| { | { | ||||
| } | } | ||||
| public bool Equal(float[] f1, float[] f2) | |||||
| { | |||||
| bool ret = false; | |||||
| var tolerance = .000001f; | |||||
| for (var i = 0; i < f1.Length; i++) | |||||
| { | |||||
| ret = Math.Abs(f1[i] - f2[i]) <= tolerance; | |||||
| if (!ret) | |||||
| break; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| public bool Equal(double[] d1, double[] d2) | |||||
| { | |||||
| bool ret = false; | |||||
| var tolerance = .000000000000001f; | |||||
| for (var i = 0; i < d1.Length; i++) | |||||
| { | |||||
| ret = Math.Abs(d1[i] - d2[i]) <= tolerance; | |||||
| if (!ret) | |||||
| break; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -6,10 +6,10 @@ using System.Threading; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class EnforcedSinglethreadingTests : CApiTest | |||||
| public class EnforcedSinglethreadingTests | |||||
| { | { | ||||
| private static readonly object _singlethreadLocker = new object(); | private static readonly object _singlethreadLocker = new object(); | ||||
| @@ -1,6 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using TensorFlowNET.UnitTest; | using TensorFlowNET.UnitTest; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.KerasApi; | |||||
| namespace Tensorflow.UnitTest | namespace Tensorflow.UnitTest | ||||
| { | { | ||||
| @@ -15,6 +16,7 @@ namespace Tensorflow.UnitTest | |||||
| [TestCleanup] | [TestCleanup] | ||||
| public void TestClean() | public void TestClean() | ||||
| { | { | ||||
| keras.backend.clear_session(); | |||||
| tf.enable_eager_execution(); | tf.enable_eager_execution(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -5,7 +5,7 @@ using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.nn_test | namespace TensorFlowNET.UnitTest.nn_test | ||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class ActivationFunctionTest : TFNetApiTest | |||||
| public class ActivationFunctionTest : EagerModeTestBase | |||||
| { | { | ||||
| // A constant vector of size 6 | // A constant vector of size 6 | ||||
| Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); | Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); | ||||
| @@ -6,7 +6,7 @@ using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.ManagedAPI | namespace TensorFlowNET.UnitTest.ManagedAPI | ||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class BitwiseApiTest : TFNetApiTest | |||||
| public class BitwiseApiTest : EagerModeTestBase | |||||
| { | { | ||||
| [TestInitialize] | [TestInitialize] | ||||
| public void Init() | public void Init() | ||||
| @@ -7,7 +7,7 @@ using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.ManagedAPI | namespace TensorFlowNET.UnitTest.ManagedAPI | ||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class FunctionApiTest : TFNetApiTest | |||||
| public class FunctionApiTest : EagerModeTestBase | |||||
| { | { | ||||
| Tensor Min(Tensor a, Tensor b) | Tensor Min(Tensor a, Tensor b) | ||||
| { | { | ||||
| @@ -6,7 +6,7 @@ using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.ManagedAPI | namespace TensorFlowNET.UnitTest.ManagedAPI | ||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class MathApiTest : TFNetApiTest | |||||
| public class MathApiTest : EagerModeTestBase | |||||
| { | { | ||||
| // A constant vector of size 6 | // A constant vector of size 6 | ||||
| Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); | Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); | ||||
| @@ -1,35 +0,0 @@ | |||||
| using System; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | |||||
| public class TFNetApiTest | |||||
| { | |||||
| public bool Equal(float[] f1, float[] f2) | |||||
| { | |||||
| bool ret = false; | |||||
| var tolerance = .000001f; | |||||
| for (var i = 0; i < f1.Length; i++) | |||||
| { | |||||
| ret = Math.Abs(f1[i] - f2[i]) <= tolerance; | |||||
| if (!ret) | |||||
| break; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| public bool Equal(double[] d1, double[] d2) | |||||
| { | |||||
| bool ret = false; | |||||
| var tolerance = .000000000000001f; | |||||
| for (var i = 0; i < d1.Length; i++) | |||||
| { | |||||
| ret = Math.Abs(d1[i] - d2[i]) <= tolerance; | |||||
| if (!ret) | |||||
| break; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,86 +0,0 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Linq; | |||||
| using Tensorflow; | |||||
| using Tensorflow.UnitTest; | |||||
| namespace TensorFlowNET.UnitTest.nn_test | |||||
| { | |||||
| [TestClass] | |||||
| public class ZeroFractionTest : GraphModeTestBase | |||||
| { | |||||
| protected double _ZeroFraction(NDArray x) | |||||
| { | |||||
| assert(x.shape); | |||||
| int total_elements = np.prod(x.shape); | |||||
| var eps = 1e-8; | |||||
| var nonzeros = x.Data<double>().Count(d => Math.Abs(d) > eps); | |||||
| return 1.0 - nonzeros / (double)total_elements; | |||||
| } | |||||
| [Ignore("TODO implement nn_impl.zero_fraction")] | |||||
| [TestMethod] | |||||
| public void testZeroFraction() | |||||
| { | |||||
| var x_shape = new Shape(5, 17); | |||||
| var x_np = np.random.randint(0, 2, x_shape); | |||||
| //x_np.astype(np.float32); | |||||
| var y_np = this._ZeroFraction(x_np); | |||||
| var x_tf = constant_op.constant(x_np); | |||||
| x_tf.set_shape(x_shape); | |||||
| var y_tf = nn_impl.zero_fraction(x_tf); | |||||
| var y_tf_np = self.evaluate<NDArray>(y_tf); | |||||
| var eps = 1e-8; | |||||
| self.assertAllClose(y_tf_np, y_np, eps); | |||||
| } | |||||
| [Ignore("TODO implement nn_impl.zero_fraction")] | |||||
| [TestMethod] | |||||
| public void testZeroFractionEmpty() | |||||
| { | |||||
| var x = np.zeros(0); | |||||
| var y = self.evaluate<NDArray>(nn_impl.zero_fraction(new Tensor(x))); | |||||
| self.assertTrue(np.isnan(y)); | |||||
| } | |||||
| [Ignore("TODO implement nn_impl.zero_fraction")] | |||||
| [TestMethod] | |||||
| public void testZeroFraction2_27Zeros() | |||||
| { | |||||
| var sparsity = nn_impl.zero_fraction( | |||||
| array_ops.zeros(new Shape((int)Math.Pow(2, 27 * 1.01)), dtypes.int8)); | |||||
| self.assertAllClose(1.0, self.evaluate<NDArray>(sparsity)); | |||||
| } | |||||
| [Ignore("TODO implement nn_impl.zero_fraction")] | |||||
| [TestMethod] | |||||
| public void testZeroFraction2_27Ones() | |||||
| { | |||||
| var sparsity = nn_impl.zero_fraction( | |||||
| array_ops.ones(new TensorShape((int)Math.Pow(2, 27 * 1.01)), dtypes.int8)); | |||||
| self.assertAllClose(0.0, self.evaluate<NDArray>(sparsity)); | |||||
| } | |||||
| [Ignore("TODO implement nn_impl.zero_fraction")] | |||||
| [TestMethod] | |||||
| public void testUnknownSize() | |||||
| { | |||||
| var value = array_ops.placeholder(dtype: dtypes.float32); | |||||
| var sparsity = nn_impl.zero_fraction(value); | |||||
| using (var sess = self.cached_session()) | |||||
| { | |||||
| // TODO: make this compile | |||||
| //self.assertAllClose( | |||||
| // 0.25, | |||||
| // sess.run(sparsity, {value: [[0., 1.], [0.3, 2.]]})); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -73,6 +73,7 @@ namespace TensorFlowNET.UnitTest | |||||
| tf.peak_default_graph().Should().BeNull(); | tf.peak_default_graph().Should().BeNull(); | ||||
| var beforehand = tf.get_default_graph(); //this should create default automatically. | var beforehand = tf.get_default_graph(); //this should create default automatically. | ||||
| beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread."); | beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread."); | ||||
| beforehand.as_default(); | |||||
| tf.peak_default_graph().Should().NotBeNull(); | tf.peak_default_graph().Should().NotBeNull(); | ||||
| using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System.IO; | |||||
| using System; | |||||
| using System.IO; | |||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| { | { | ||||
| @@ -6,8 +7,16 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| public static string GetFullPathFromDataDir(string fileName) | public static string GetFullPathFromDataDir(string fileName) | ||||
| { | { | ||||
| var dir = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "..", "..", "data"); | |||||
| return Path.GetFullPath(Path.Combine(dir, fileName)); | |||||
| var dataDir = GetRootContentDir(Directory.GetCurrentDirectory()); | |||||
| return Path.Combine(dataDir, fileName); | |||||
| } | |||||
| static string GetRootContentDir(string dir) | |||||
| { | |||||
| var path = Path.GetFullPath(Path.Combine(dir, "data")); | |||||
| if (Directory.Exists(path)) | |||||
| return path; | |||||
| return GetRootContentDir(Path.GetFullPath(Path.Combine(dir, ".."))); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,875 +0,0 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using Tensorflow; | |||||
| using Tensorflow.UnitTest; | |||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.nest_test | |||||
| { | |||||
| /// <summary> | |||||
| /// excerpt of tensorflow/python/framework/util/nest_test.py | |||||
| /// </summary> | |||||
| [TestClass] | |||||
| public class NestTest : GraphModeTestBase | |||||
| { | |||||
| [TestInitialize] | |||||
| public void TestInitialize() | |||||
| { | |||||
| tf.Graph().as_default(); | |||||
| } | |||||
| //public class PointXY | |||||
| //{ | |||||
| // public double x; | |||||
| // public double y; | |||||
| //} | |||||
| // if attr: | |||||
| // class BadAttr(object): | |||||
| // """Class that has a non-iterable __attrs_attrs__.""" | |||||
| // __attrs_attrs__ = None | |||||
| // @attr.s | |||||
| // class SampleAttr(object): | |||||
| // field1 = attr.ib() | |||||
| // field2 = attr.ib() | |||||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| // def testAttrsFlattenAndPack(self) : | |||||
| // if attr is None: | |||||
| // self.skipTest("attr module is unavailable.") | |||||
| // field_values = [1, 2] | |||||
| // sample_attr = NestTest.SampleAttr(* field_values) | |||||
| // self.assertFalse(nest._is_attrs(field_values)) | |||||
| // self.assertTrue(nest._is_attrs(sample_attr)) | |||||
| // flat = nest.flatten(sample_attr) | |||||
| // self.assertEqual(field_values, flat) | |||||
| // restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) | |||||
| // self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) | |||||
| // self.assertEqual(restructured_from_flat, sample_attr) | |||||
| //# Check that flatten fails if attributes are not iterable | |||||
| // with self.assertRaisesRegexp(TypeError, "object is not iterable"): | |||||
| // flat = nest.flatten(NestTest.BadAttr()) | |||||
| [Ignore] | |||||
| [TestMethod] | |||||
| public void testFlattenAndPack() | |||||
| { | |||||
| object structure = new object[] { new object[] { 3, 4 }, 5, new object[] { 6, 7, new object[] { 9, 10 }, 8 } }; | |||||
| var flat = new List<object> { "a", "b", "c", "d", "e", "f", "g", "h" }; | |||||
| self.assertEqual(nest.flatten(structure), new[] { 3, 4, 5, 6, 7, 9, 10, 8 }); | |||||
| self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(), | |||||
| JArray.FromObject(new object[] { new object[] { "a", "b" }, "c", new object[] { "d", "e", new object[] { "f", "g" }, "h" } }).ToString()); | |||||
| structure = new object[] { new Hashtable { ["x"] = 4, ["y"] = 2 }, new object[] { new object[] { new Hashtable { ["x"] = 1, ["y"] = 0 }, }, } }; | |||||
| flat = new List<object> { 4, 2, 1, 0 }; | |||||
| self.assertEqual(nest.flatten(structure), flat); | |||||
| var restructured_from_flat = nest.pack_sequence_as(structure, flat) as object[]; | |||||
| //Console.WriteLine(JArray.FromObject(restructured_from_flat)); | |||||
| self.assertEqual(restructured_from_flat, structure); | |||||
| self.assertEqual((restructured_from_flat[0] as Hashtable)["x"], 4); | |||||
| self.assertEqual((restructured_from_flat[0] as Hashtable)["y"], 2); | |||||
| self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["x"], 1); | |||||
| self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["y"], 0); | |||||
| self.assertEqual(new List<object> { 5 }, nest.flatten(5)); | |||||
| var flat1 = nest.flatten(np.array(new[] { 5 })); | |||||
| self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat1); | |||||
| self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" })); | |||||
| self.assertEqual(np.array(new[] { 5 }), | |||||
| nest.pack_sequence_as("scalar", new List<object> { np.array(new[] { 5 }) })); | |||||
| Assert.ThrowsException<ValueError>(() => nest.pack_sequence_as("scalar", new List<object>() { 4, 5 })); | |||||
| Assert.ThrowsException<ValueError>(() => | |||||
| nest.pack_sequence_as(new object[] { 5, 6, new object[] { 7, 8 } }, new List<object> { "a", "b", "c" })); | |||||
| } | |||||
| // @parameterized.parameters({"mapping_type": collections.OrderedDict | |||||
| // }, | |||||
| // {"mapping_type": _CustomMapping | |||||
| //}) | |||||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| // def testFlattenDictOrder(self, mapping_type) : | |||||
| // """`flatten` orders dicts by key, including OrderedDicts.""" | |||||
| // ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) | |||||
| // plain = {"d": 3, "b": 1, "a": 0, "c": 2} | |||||
| // ordered_flat = nest.flatten(ordered) | |||||
| // plain_flat = nest.flatten(plain) | |||||
| // self.assertEqual([0, 1, 2, 3], ordered_flat) | |||||
| // self.assertEqual([0, 1, 2, 3], plain_flat) | |||||
| // @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||||
| // {"mapping_type": _CustomMapping}) | |||||
| // def testPackDictOrder(self, mapping_type): | |||||
| // """Packing orders dicts by key, including OrderedDicts.""" | |||||
| // custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) | |||||
| // plain = {"d": 0, "b": 0, "a": 0, "c": 0} | |||||
| // seq = [0, 1, 2, 3] | |||||
| //custom_reconstruction = nest.pack_sequence_as(custom, seq) | |||||
| //plain_reconstruction = nest.pack_sequence_as(plain, seq) | |||||
| // self.assertIsInstance(custom_reconstruction, mapping_type) | |||||
| // self.assertIsInstance(plain_reconstruction, dict) | |||||
| // self.assertEqual( | |||||
| // mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), | |||||
| // custom_reconstruction) | |||||
| // self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) | |||||
| // Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name | |||||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| // def testFlattenAndPack_withDicts(self) : | |||||
| // # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. | |||||
| // mess = [ | |||||
| // "z", | |||||
| // NestTest.Abc(3, 4), { | |||||
| // "d": _CustomMapping({ | |||||
| // 41: 4 | |||||
| // }), | |||||
| // "c": [ | |||||
| // 1, | |||||
| // collections.OrderedDict([ | |||||
| // ("b", 3), | |||||
| // ("a", 2), | |||||
| // ]), | |||||
| // ], | |||||
| // "b": 5 | |||||
| // }, 17 | |||||
| // ] | |||||
| // flattened = nest.flatten(mess) | |||||
| // self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) | |||||
| // structure_of_mess = [ | |||||
| // 14, | |||||
| // NestTest.Abc("a", True), | |||||
| // { | |||||
| // "d": _CustomMapping({ | |||||
| // 41: 42 | |||||
| // }), | |||||
| // "c": [ | |||||
| // 0, | |||||
| // collections.OrderedDict([ | |||||
| // ("b", 9), | |||||
| // ("a", 8), | |||||
| // ]), | |||||
| // ], | |||||
| // "b": 3 | |||||
| // }, | |||||
| // "hi everybody", | |||||
| // ] | |||||
| // unflattened = nest.pack_sequence_as(structure_of_mess, flattened) | |||||
| // self.assertEqual(unflattened, mess) | |||||
| // # Check also that the OrderedDict was created, with the correct key order. | |||||
| //unflattened_ordered_dict = unflattened[2]["c"][1] | |||||
| // self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) | |||||
| // self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) | |||||
| // unflattened_custom_mapping = unflattened[2]["d"] | |||||
| // self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) | |||||
| // self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) | |||||
| [TestMethod] | |||||
| public void testFlatten_numpyIsNotFlattened() | |||||
| { | |||||
| var structure = np.array(1, 2, 3); | |||||
| var flattened = nest.flatten(structure); | |||||
| self.assertEqual(len(flattened), 1); | |||||
| } | |||||
| [TestMethod] | |||||
| public void testFlatten_stringIsNotFlattened() | |||||
| { | |||||
| var structure = "lots of letters"; | |||||
| var flattened = nest.flatten(structure); | |||||
| self.assertEqual(len(flattened), 1); | |||||
| var unflattened = nest.pack_sequence_as("goodbye", flattened); | |||||
| self.assertEqual(structure, unflattened); | |||||
| } | |||||
| // def testPackSequenceAs_notIterableError(self) : | |||||
| // with self.assertRaisesRegexp(TypeError, | |||||
| // "flat_sequence must be a sequence"): | |||||
| // nest.pack_sequence_as("hi", "bye") | |||||
| [TestMethod] | |||||
| public void testPackSequenceAs_wrongLengthsError() | |||||
| { | |||||
| Assert.ThrowsException<ValueError>(() => | |||||
| { | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // "Structure had 2 elements, but flat_sequence had 3 elements."): | |||||
| nest.pack_sequence_as(new object[] { "hello", "world" }, new object[] { "and", "goodbye", "again" }); | |||||
| }); | |||||
| } | |||||
| [Ignore] | |||||
| [TestMethod] | |||||
| public void testIsSequence() | |||||
| { | |||||
| self.assertFalse(nest.is_sequence("1234")); | |||||
| self.assertTrue(nest.is_sequence(new object[] { 1, 3, new object[] { 4, 5 } })); | |||||
| // TODO: ValueTuple<T,T> | |||||
| //self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))); | |||||
| self.assertTrue(nest.is_sequence(new object[] { })); | |||||
| self.assertTrue(nest.is_sequence(new Hashtable { ["a"] = 1, ["b"] = 2 })); | |||||
| self.assertFalse(nest.is_sequence(new HashSet<int> { 1, 2 })); | |||||
| var ones = array_ops.ones(new int[] { 2, 3 }); | |||||
| self.assertFalse(nest.is_sequence(ones)); | |||||
| self.assertFalse(nest.is_sequence(gen_math_ops.tanh(ones))); | |||||
| self.assertFalse(nest.is_sequence(np.ones(new int[] { 4, 5 }))); | |||||
| } | |||||
| // @parameterized.parameters({"mapping_type": _CustomMapping}, | |||||
| // {"mapping_type": dict}) | |||||
| // def testFlattenDictItems(self, mapping_type): | |||||
| // dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))}) | |||||
| // flat = {4: "a", 5: "b", 6: "c", 8: "d"} | |||||
| // self.assertEqual(nest.flatten_dict_items(dictionary), flat) | |||||
| // with self.assertRaises(TypeError): | |||||
| // nest.flatten_dict_items(4) | |||||
| // bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))}) | |||||
| // with self.assertRaisesRegexp(ValueError, "not unique"): | |||||
| // nest.flatten_dict_items(bad_dictionary) | |||||
| // another_bad_dictionary = mapping_type({ | |||||
| // (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) | |||||
| // }) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): | |||||
| // nest.flatten_dict_items(another_bad_dictionary) | |||||
| //# pylint does not correctly recognize these as class names and | |||||
| //# suggests to use variable style under_score naming. | |||||
| //# pylint: disable=invalid-name | |||||
| // Named0ab = collections.namedtuple("named_0", ("a", "b")) | |||||
| // Named1ab = collections.namedtuple("named_1", ("a", "b")) | |||||
| // SameNameab = collections.namedtuple("same_name", ("a", "b")) | |||||
| // SameNameab2 = collections.namedtuple("same_name", ("a", "b")) | |||||
| // SameNamexy = collections.namedtuple("same_name", ("x", "y")) | |||||
| // SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) | |||||
| // SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) | |||||
| // NotSameName = collections.namedtuple("not_same_name", ("a", "b")) | |||||
| // # pylint: enable=invalid-name | |||||
| // class SameNamedType1(SameNameab): | |||||
| // pass | |||||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| // def testAssertSameStructure(self): | |||||
| // structure1 = (((1, 2), 3), 4, (5, 6)) | |||||
| // structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||||
| // structure_different_num_elements = ("spam", "eggs") | |||||
| // structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) | |||||
| // nest.assert_same_structure(structure1, structure2) | |||||
| // nest.assert_same_structure("abc", 1.0) | |||||
| // nest.assert_same_structure("abc", np.array([0, 1])) | |||||
| // nest.assert_same_structure("abc", constant_op.constant([0, 1])) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // ("The two structures don't have the same nested structure\\.\n\n" | |||||
| // "First structure:.*?\n\n" | |||||
| // "Second structure:.*\n\n" | |||||
| // "More specifically: Substructure " | |||||
| // r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' | |||||
| // 'substructure "type=str str=spam" is not\n' | |||||
| // "Entire first structure:\n" | |||||
| // r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" | |||||
| // "Entire second structure:\n" | |||||
| // r"\(\., \.\)")): | |||||
| // nest.assert_same_structure(structure1, structure_different_num_elements) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // ("The two structures don't have the same nested structure\\.\n\n" | |||||
| // "First structure:.*?\n\n" | |||||
| // "Second structure:.*\n\n" | |||||
| // r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||||
| // r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' | |||||
| // "is not")): | |||||
| // nest.assert_same_structure([0, 1], np.array([0, 1])) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // ("The two structures don't have the same nested structure\\.\n\n" | |||||
| // "First structure:.*?\n\n" | |||||
| // "Second structure:.*\n\n" | |||||
| // r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||||
| // 'is a sequence, while substructure "type=int str=0" ' | |||||
| // "is not")): | |||||
| // nest.assert_same_structure(0, [0, 1]) | |||||
| // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // ("don't have the same nested structure\\.\n\n" | |||||
| // "First structure: .*?\n\nSecond structure: ")): | |||||
| // nest.assert_same_structure(structure1, structure_different_nesting) | |||||
| // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), | |||||
| // NestTest.Named0ab("a", "b")) | |||||
| // nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||||
| // NestTest.Named0ab("a", "b")) | |||||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| // NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // ("don't have the same nested structure\\.\n\n" | |||||
| // "First structure: .*?\n\nSecond structure: ")): | |||||
| // nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||||
| // NestTest.Named0ab([3], 4)) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // ("don't have the same nested structure\\.\n\n" | |||||
| // "First structure: .*?\n\nSecond structure: ")): | |||||
| // nest.assert_same_structure([[3], 4], [3, [4]]) | |||||
| // structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||||
| // with self.assertRaisesRegexp(TypeError, | |||||
| // "don't have the same sequence type"): | |||||
| // nest.assert_same_structure(structure1, structure1_list) | |||||
| // nest.assert_same_structure(structure1, structure2, check_types= False) | |||||
| // nest.assert_same_structure(structure1, structure1_list, check_types=False) | |||||
| // with self.assertRaisesRegexp(ValueError, | |||||
| // "don't have the same set of keys"): | |||||
| // nest.assert_same_structure({"a": 1}, {"b": 1}) | |||||
| // nest.assert_same_structure(NestTest.SameNameab(0, 1), | |||||
| // NestTest.SameNameab2(2, 3)) | |||||
| // # This assertion is expected to pass: two namedtuples with the same | |||||
| // # name and field names are considered to be identical. | |||||
| // nest.assert_same_structure( | |||||
| // NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), | |||||
| // NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) | |||||
| // expected_message = "The two structures don't have the same.*" | |||||
| // with self.assertRaisesRegexp(ValueError, expected_message): | |||||
| // nest.assert_same_structure( | |||||
| // NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), | |||||
| // NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) | |||||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| // NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) | |||||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| // NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) | |||||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| // NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) | |||||
| // EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name | |||||
| // def testHeterogeneousComparison(self): | |||||
| // nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3)) | |||||
| // nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) | |||||
| [Ignore] | |||||
| [TestMethod] | |||||
| public void testMapStructure() | |||||
| { | |||||
| var structure1 = new object[] { new object[] { new object[] { 1, 2 }, 3 }, 4, new object[] { 5, 6 } }; | |||||
| var structure2 = new object[] { new object[] { new object[] { 7, 8 }, 9 }, 10, new object[] { 11, 12 } }; | |||||
| var structure1_plus1 = nest.map_structure(x => (int)x + 1, structure1); | |||||
| var structure1_strings = nest.map_structure(x => $"{x}", structure1); | |||||
| var s = JArray.FromObject(structure1_plus1).ToString(); | |||||
| Console.WriteLine(s); | |||||
| // nest.assert_same_structure(structure1, structure1_plus1) | |||||
| self.assertAllEqual(nest.flatten(structure1_plus1), new object[] { 2, 3, 4, 5, 6, 7 }); | |||||
| self.assertAllEqual(nest.flatten(structure1_strings), new object[] { "1", "2", "3", "4", "5", "6" }); | |||||
| var structure1_plus_structure2 = nest.map_structure(x => (int)(x[0]) + (int)(x[1]), structure1, structure2); | |||||
| self.assertEqual( | |||||
| new object[] { new object[] { new object[] { 1 + 7, 2 + 8 }, 3 + 9 }, 4 + 10, new object[] { 5 + 11, 6 + 12 } }, | |||||
| structure1_plus_structure2); | |||||
| // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) | |||||
| // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) | |||||
| // # Empty structures | |||||
| // self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) | |||||
| // self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) | |||||
| // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) | |||||
| // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, | |||||
| // NestTest.EmptyNT())) | |||||
| // # This is checking actual equality of types, empty list != empty tuple | |||||
| // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) | |||||
| // with self.assertRaisesRegexp(TypeError, "callable"): | |||||
| // nest.map_structure("bad", structure1_plus1) | |||||
| // with self.assertRaisesRegexp(ValueError, "at least one structure"): | |||||
| // nest.map_structure(lambda x: x) | |||||
| // with self.assertRaisesRegexp(ValueError, "same number of elements"): | |||||
| // nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) | |||||
| // with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
| // nest.map_structure(lambda x, y: None, 3, (3,)) | |||||
| // with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||||
| // nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) | |||||
| // with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
| // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) | |||||
| // structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||||
| // with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||||
| // nest.map_structure(lambda x, y: None, structure1, structure1_list) | |||||
| // nest.map_structure(lambda x, y: None, structure1, structure1_list, | |||||
| // check_types=False) | |||||
| // with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
| // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), | |||||
| // check_types=False) | |||||
| // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||||
| // nest.map_structure(lambda x: None, structure1, foo="a") | |||||
| // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||||
| // nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") | |||||
| // ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name | |||||
| } | |||||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| // def testMapStructureWithStrings(self) : | |||||
| // inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) | |||||
| // inp_b = NestTest.ABTuple(a=2, b=(1, 3)) | |||||
| // out = nest.map_structure(lambda string, repeats: string* repeats, | |||||
| // inp_a, | |||||
| // inp_b) | |||||
| // self.assertEqual("foofoo", out.a) | |||||
| // self.assertEqual("bar", out.b[0]) | |||||
| // self.assertEqual("bazbazbaz", out.b[1]) | |||||
| // nt = NestTest.ABTuple(a=("something", "something_else"), | |||||
| // b="yet another thing") | |||||
| // rev_nt = nest.map_structure(lambda x: x[::- 1], nt) | |||||
| // # Check the output is the correct structure, and all strings are reversed. | |||||
| // nest.assert_same_structure(nt, rev_nt) | |||||
| // self.assertEqual(nt.a[0][::- 1], rev_nt.a[0]) | |||||
| // self.assertEqual(nt.a[1][::- 1], rev_nt.a[1]) | |||||
| // self.assertEqual(nt.b[::- 1], rev_nt.b) | |||||
| // @test_util.run_deprecated_v1 | |||||
| // def testMapStructureOverPlaceholders(self) : | |||||
| // inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||||
| // array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||||
| // inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||||
| // array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||||
| // output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) | |||||
| // nest.assert_same_structure(output, inp_a) | |||||
| // self.assertShapeEqual(np.zeros((3, 4)), output[0]) | |||||
| // self.assertShapeEqual(np.zeros((3, 7)), output[1]) | |||||
| // feed_dict = { | |||||
| // inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), | |||||
| // inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) | |||||
| // } | |||||
| // with self.cached_session() as sess: | |||||
| // output_np = sess.run(output, feed_dict=feed_dict) | |||||
| // self.assertAllClose(output_np[0], | |||||
| // feed_dict[inp_a][0] + feed_dict[inp_b][0]) | |||||
| // self.assertAllClose(output_np[1], | |||||
| // feed_dict[inp_a][1] + feed_dict[inp_b][1]) | |||||
| // def testAssertShallowStructure(self): | |||||
| // inp_ab = ["a", "b"] | |||||
| //inp_abc = ["a", "b", "c"] | |||||
| //expected_message = ( | |||||
| // "The two structures don't have the same sequence length. Input " | |||||
| // "structure has length 2, while shallow structure has length 3.") | |||||
| // with self.assertRaisesRegexp(ValueError, expected_message): | |||||
| // nest.assert_shallow_structure(inp_abc, inp_ab) | |||||
| // inp_ab1 = [(1, 1), (2, 2)] | |||||
| // inp_ab2 = [[1, 1], [2, 2]] | |||||
| // expected_message = ( | |||||
| // "The two structures don't have the same sequence type. Input structure " | |||||
| // "has type <(type|class) 'tuple'>, while shallow structure has type " | |||||
| // "<(type|class) 'list'>.") | |||||
| // with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| // nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||||
| // nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False) | |||||
| // inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} | |||||
| // inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} | |||||
| // expected_message = ( | |||||
| // r"The two structures don't have the same keys. Input " | |||||
| // r"structure has keys \['c'\], while shallow structure has " | |||||
| // r"keys \['d'\].") | |||||
| // with self.assertRaisesRegexp(ValueError, expected_message): | |||||
| // nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||||
| // inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) | |||||
| // inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) | |||||
| // nest.assert_shallow_structure(inp_ab, inp_ba) | |||||
| // # This assertion is expected to pass: two namedtuples with the same | |||||
| //# name and field names are considered to be identical. | |||||
| //inp_shallow = NestTest.SameNameab(1, 2) | |||||
| // inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) | |||||
| // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) | |||||
| // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) | |||||
| // def testFlattenUpTo(self): | |||||
| // # Shallow tree ends at scalar. | |||||
| // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||||
| // shallow_tree = [[True, True], [False, True]] | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) | |||||
| // self.assertEqual(flattened_shallow_tree, [True, True, False, True]) | |||||
| //# Shallow tree ends at string. | |||||
| // input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] | |||||
| // shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] | |||||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| // input_tree) | |||||
| // input_tree_flattened = nest.flatten(input_tree) | |||||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| // [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) | |||||
| // self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) | |||||
| // # Make sure dicts are correctly flattened, yielding values, not keys. | |||||
| //input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} | |||||
| // shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} | |||||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| // input_tree) | |||||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| // [1, { "c": 2}, 3, (4, 5)]) | |||||
| // # Namedtuples. | |||||
| // ab_tuple = NestTest.ABTuple | |||||
| // input_tree = ab_tuple(a =[0, 1], b = 2) | |||||
| // shallow_tree = ab_tuple(a= 0, b= 1) | |||||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| // input_tree) | |||||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| // [[0, 1], 2]) | |||||
| // # Nested dicts, OrderedDicts and namedtuples. | |||||
| // input_tree = collections.OrderedDict( | |||||
| // [("a", ab_tuple(a =[0, {"b": 1}], b=2)), | |||||
| // ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) | |||||
| // shallow_tree = input_tree | |||||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| // input_tree) | |||||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) | |||||
| // shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) | |||||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| // input_tree) | |||||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| // [ab_tuple(a =[0, { "b": 1}], b=2), | |||||
| // 3, | |||||
| // collections.OrderedDict([("f", 4)])]) | |||||
| // shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) | |||||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| // input_tree) | |||||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| // [ab_tuple(a =[0, {"b": 1}], b=2), | |||||
| // {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) | |||||
| // ## Shallow non-list edge-case. | |||||
| // # Using iterable elements. | |||||
| // input_tree = ["input_tree"] | |||||
| //shallow_tree = "shallow_tree" | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| // input_tree = ["input_tree_0", "input_tree_1"] | |||||
| //shallow_tree = "shallow_tree" | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| // # Using non-iterable elements. | |||||
| //input_tree = [0] | |||||
| //shallow_tree = 9 | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| // input_tree = [0, 1] | |||||
| //shallow_tree = 9 | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| // ## Both non-list edge-case. | |||||
| //# Using iterable elements. | |||||
| //input_tree = "input_tree" | |||||
| // shallow_tree = "shallow_tree" | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| // # Using non-iterable elements. | |||||
| //input_tree = 0 | |||||
| // shallow_tree = 0 | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| // ## Input non-list edge-case. | |||||
| //# Using iterable elements. | |||||
| //input_tree = "input_tree" | |||||
| // shallow_tree = ["shallow_tree"] | |||||
| //expected_message = ("If shallow structure is a sequence, input must also " | |||||
| // "be a sequence. Input has type: <(type|class) 'str'>.") | |||||
| // with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| // input_tree = "input_tree" | |||||
| // shallow_tree = ["shallow_tree_9", "shallow_tree_8"] | |||||
| //with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| //# Using non-iterable elements. | |||||
| // input_tree = 0 | |||||
| // shallow_tree = [9] | |||||
| //expected_message = ("If shallow structure is a sequence, input must also " | |||||
| // "be a sequence. Input has type: <(type|class) 'int'>.") | |||||
| // with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| // input_tree = 0 | |||||
| // shallow_tree = [9, 8] | |||||
| //with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| // def testMapStructureUpTo(self) : | |||||
| // # Named tuples. | |||||
| // ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||||
| // op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||||
| // inp_val = ab_tuple(a= 2, b= 3) | |||||
| // inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3)) | |||||
| // out = nest.map_structure_up_to( | |||||
| // inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) | |||||
| // self.assertEqual(out.a, 6) | |||||
| // self.assertEqual(out.b, 15) | |||||
| // # Lists. | |||||
| // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||||
| // name_list = ["evens", ["odds", "primes"]] | |||||
| // out = nest.map_structure_up_to( | |||||
| // name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), | |||||
| // name_list, data_list) | |||||
| // self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) | |||||
| // # Dicts. | |||||
| // inp_val = dict(a= 2, b= 3) | |||||
| // inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) | |||||
| // out = nest.map_structure_up_to( | |||||
| // inp_val, | |||||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| // self.assertEqual(out["a"], 6) | |||||
| // self.assertEqual(out["b"], 15) | |||||
| // # Non-equal dicts. | |||||
| // inp_val = dict(a= 2, b= 3) | |||||
| // inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) | |||||
| // with self.assertRaisesRegexp(ValueError, "same keys"): | |||||
| // nest.map_structure_up_to( | |||||
| // inp_val, | |||||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| // # Dict+custom mapping. | |||||
| // inp_val = dict(a= 2, b= 3) | |||||
| // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) | |||||
| // out = nest.map_structure_up_to( | |||||
| // inp_val, | |||||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| // self.assertEqual(out["a"], 6) | |||||
| // self.assertEqual(out["b"], 15) | |||||
| // # Non-equal dict/mapping. | |||||
| // inp_val = dict(a= 2, b= 3) | |||||
| // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) | |||||
| // with self.assertRaisesRegexp(ValueError, "same keys"): | |||||
| // nest.map_structure_up_to( | |||||
| // inp_val, | |||||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| // def testGetTraverseShallowStructure(self): | |||||
| // scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] | |||||
| // scalar_traverse_r = nest.get_traverse_shallow_structure( | |||||
| // lambda s: not isinstance(s, tuple), | |||||
| // scalar_traverse_input) | |||||
| // self.assertEqual(scalar_traverse_r, | |||||
| // [True, True, False, [True, True], {"a": False}, []]) | |||||
| // nest.assert_shallow_structure(scalar_traverse_r, | |||||
| // scalar_traverse_input) | |||||
| // structure_traverse_input = [(1, [2]), ([1], 2)] | |||||
| // structure_traverse_r = nest.get_traverse_shallow_structure( | |||||
| // lambda s: (True, False) if isinstance(s, tuple) else True, | |||||
| // structure_traverse_input) | |||||
| // self.assertEqual(structure_traverse_r, | |||||
| // [(True, False), ([True], False)]) | |||||
| // nest.assert_shallow_structure(structure_traverse_r, | |||||
| // structure_traverse_input) | |||||
| // with self.assertRaisesRegexp(TypeError, "returned structure"): | |||||
| // nest.get_traverse_shallow_structure(lambda _: [True], 0) | |||||
| // with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): | |||||
| // nest.get_traverse_shallow_structure(lambda _: 1, [1]) | |||||
| // with self.assertRaisesRegexp( | |||||
| // TypeError, "didn't return a depth=1 structure of bools"): | |||||
| // nest.get_traverse_shallow_structure(lambda _: [1], [1]) | |||||
| // def testYieldFlatStringPaths(self): | |||||
| // for inputs_expected in ({"inputs": [], "expected": []}, | |||||
| // {"inputs": 3, "expected": [()]}, | |||||
| // {"inputs": [3], "expected": [(0,)]}, | |||||
| // {"inputs": {"a": 3}, "expected": [("a",)]}, | |||||
| // {"inputs": {"a": {"b": 4}}, | |||||
| // "expected": [("a", "b")]}, | |||||
| // {"inputs": [{"a": 2}], "expected": [(0, "a")]}, | |||||
| // {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, | |||||
| // {"inputs": [{"a": [(23, 42)]}], | |||||
| // "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, | |||||
| // {"inputs": [{"a": ([23], 42)}], | |||||
| // "expected": [(0, "a", 0, 0), (0, "a", 1)]}, | |||||
| // {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, | |||||
| // "expected": [("a", "a"), ("c", 0, 0, 0)]}, | |||||
| // {"inputs": {"0": [{"1": 23}]}, | |||||
| // "expected": [("0", 0, "1")]}): | |||||
| // inputs = inputs_expected["inputs"] | |||||
| // expected = inputs_expected["expected"] | |||||
| // self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) | |||||
| // def testFlattenWithStringPaths(self): | |||||
| // for inputs_expected in ( | |||||
| // {"inputs": [], "expected": []}, | |||||
| // {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, | |||||
| // {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): | |||||
| // inputs = inputs_expected["inputs"] | |||||
| // expected = inputs_expected["expected"] | |||||
| // self.assertEqual( | |||||
| // nest.flatten_with_joined_string_paths(inputs, separator="/"), | |||||
| // expected) | |||||
| // # Need a separate test for namedtuple as we can't declare tuple definitions | |||||
| // # in the @parameterized arguments. | |||||
| // def testFlattenNamedTuple(self): | |||||
| // # pylint: disable=invalid-name | |||||
| // Foo = collections.namedtuple("Foo", ["a", "b"]) | |||||
| // Bar = collections.namedtuple("Bar", ["c", "d"]) | |||||
| // # pylint: enable=invalid-name | |||||
| // test_cases = [ | |||||
| // (Foo(a = 3, b = Bar(c = 23, d = 42)), | |||||
| // [("a", 3), ("b/c", 23), ("b/d", 42)]), | |||||
| // (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")), | |||||
| // [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), | |||||
| // (Bar(c = 42, d = 43), | |||||
| // [("c", 42), ("d", 43)]), | |||||
| // (Bar(c =[42], d = 43), | |||||
| // [("c/0", 42), ("d", 43)]), | |||||
| // ] | |||||
| // for inputs, expected in test_cases: | |||||
| // self.assertEqual( | |||||
| // list(nest.flatten_with_joined_string_paths(inputs)), expected) | |||||
| // @parameterized.named_parameters( | |||||
| // ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), | |||||
| // ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, | |||||
| // {"a": ("a", 4), "b": ("b", 6)}), | |||||
| // ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), | |||||
| // ("nested", | |||||
| // {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, | |||||
| // {"a": [("a/0", 10), ("a/1", 12)], | |||||
| // "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) | |||||
| // def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): | |||||
| // def format_sum(path, * values): | |||||
| // return (path, sum(values)) | |||||
| // result = nest.map_structure_with_paths(format_sum, s1, s2, | |||||
| // check_types=check_types) | |||||
| // self.assertEqual(expected, result) | |||||
| // @parameterized.named_parameters( | |||||
| // ("tuples", (1, 2), (3, 4, 5), ValueError), | |||||
| // ("dicts", {"a": 1}, {"b": 2}, ValueError), | |||||
| // ("mixed", (1, 2), [3, 4], TypeError), | |||||
| // ("nested", | |||||
| // {"a": [2, 3], "b": [1, 3]}, | |||||
| // {"b": [5, 6, 7], "a": [8, 9]}, | |||||
| // ValueError | |||||
| // )) | |||||
| // def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): | |||||
| // with self.assertRaises(error_type): | |||||
| // nest.map_structure_with_paths(lambda path, * s: 0, s1, s2) | |||||
| //class NestBenchmark(test.Benchmark): | |||||
| // def run_and_report(self, s1, s2, name): | |||||
| // burn_iter, test_iter = 100, 30000 | |||||
| // for _ in xrange(burn_iter) : | |||||
| // nest.assert_same_structure(s1, s2) | |||||
| // t0 = time.time() | |||||
| // for _ in xrange(test_iter) : | |||||
| // nest.assert_same_structure(s1, s2) | |||||
| // t1 = time.time() | |||||
| // self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, | |||||
| // name=name) | |||||
| // def benchmark_assert_structure(self): | |||||
| // s1 = (((1, 2), 3), 4, (5, 6)) | |||||
| // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||||
| // self.run_and_report(s1, s2, "assert_same_structure_6_elem") | |||||
| // s1 = (((1, 2), 3), 4, (5, 6)) * 10 | |||||
| // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 | |||||
| // self.run_and_report(s1, s2, "assert_same_structure_60_elem") | |||||
| //if __name__ == "__main__": | |||||
| // test.main() | |||||
| } | |||||
| } | |||||
| @@ -1,883 +0,0 @@ | |||||
| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Tests for utilities working with arbitrarily nested structures.""" | |||||
| from __future__ import absolute_import | |||||
| from __future__ import division | |||||
| from __future__ import print_function | |||||
| import collections | |||||
| import time | |||||
| from absl.testing import parameterized | |||||
| import numpy as np | |||||
| from six.moves import xrange # pylint: disable=redefined-builtin | |||||
| from tensorflow.python.framework import constant_op | |||||
| from tensorflow.python.framework import dtypes | |||||
| from tensorflow.python.framework import test_util | |||||
| from tensorflow.python.ops import array_ops | |||||
| from tensorflow.python.ops import math_ops | |||||
| from tensorflow.python.platform import test | |||||
| from tensorflow.python.util import nest | |||||
| try: | |||||
| import attr # pylint:disable=g-import-not-at-top | |||||
| except ImportError: | |||||
| attr = None | |||||
| class _CustomMapping(collections.Mapping): | |||||
| def __init__(self, *args, **kwargs): | |||||
| self._wrapped = dict(*args, **kwargs) | |||||
| def __getitem__(self, key): | |||||
| return self._wrapped[key] | |||||
| def __iter__(self): | |||||
| return iter(self._wrapped) | |||||
| def __len__(self): | |||||
| return len(self._wrapped) | |||||
| class NestTest(parameterized.TestCase, test.TestCase): | |||||
| PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name | |||||
| if attr: | |||||
| class BadAttr(object): | |||||
| """Class that has a non-iterable __attrs_attrs__.""" | |||||
| __attrs_attrs__ = None | |||||
| @attr.s | |||||
| class SampleAttr(object): | |||||
| field1 = attr.ib() | |||||
| field2 = attr.ib() | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testAttrsFlattenAndPack(self): | |||||
| if attr is None: | |||||
| self.skipTest("attr module is unavailable.") | |||||
| field_values = [1, 2] | |||||
| sample_attr = NestTest.SampleAttr(*field_values) | |||||
| self.assertFalse(nest._is_attrs(field_values)) | |||||
| self.assertTrue(nest._is_attrs(sample_attr)) | |||||
| flat = nest.flatten(sample_attr) | |||||
| self.assertEqual(field_values, flat) | |||||
| restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) | |||||
| self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) | |||||
| self.assertEqual(restructured_from_flat, sample_attr) | |||||
| # Check that flatten fails if attributes are not iterable | |||||
| with self.assertRaisesRegexp(TypeError, "object is not iterable"): | |||||
| flat = nest.flatten(NestTest.BadAttr()) | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testFlattenAndPack(self): | |||||
| structure = ((3, 4), 5, (6, 7, (9, 10), 8)) | |||||
| flat = ["a", "b", "c", "d", "e", "f", "g", "h"] | |||||
| self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) | |||||
| self.assertEqual( | |||||
| nest.pack_sequence_as(structure, flat), (("a", "b"), "c", | |||||
| ("d", "e", ("f", "g"), "h"))) | |||||
| structure = (NestTest.PointXY(x=4, y=2), | |||||
| ((NestTest.PointXY(x=1, y=0),),)) | |||||
| flat = [4, 2, 1, 0] | |||||
| self.assertEqual(nest.flatten(structure), flat) | |||||
| restructured_from_flat = nest.pack_sequence_as(structure, flat) | |||||
| self.assertEqual(restructured_from_flat, structure) | |||||
| self.assertEqual(restructured_from_flat[0].x, 4) | |||||
| self.assertEqual(restructured_from_flat[0].y, 2) | |||||
| self.assertEqual(restructured_from_flat[1][0][0].x, 1) | |||||
| self.assertEqual(restructured_from_flat[1][0][0].y, 0) | |||||
| self.assertEqual([5], nest.flatten(5)) | |||||
| self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) | |||||
| self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) | |||||
| self.assertEqual( | |||||
| np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) | |||||
| with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): | |||||
| nest.pack_sequence_as("scalar", [4, 5]) | |||||
| with self.assertRaisesRegexp(TypeError, "flat_sequence"): | |||||
| nest.pack_sequence_as([4, 5], "bad_sequence") | |||||
| with self.assertRaises(ValueError): | |||||
| nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) | |||||
| @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||||
| {"mapping_type": _CustomMapping}) | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testFlattenDictOrder(self, mapping_type): | |||||
| """`flatten` orders dicts by key, including OrderedDicts.""" | |||||
| ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) | |||||
| plain = {"d": 3, "b": 1, "a": 0, "c": 2} | |||||
| ordered_flat = nest.flatten(ordered) | |||||
| plain_flat = nest.flatten(plain) | |||||
| self.assertEqual([0, 1, 2, 3], ordered_flat) | |||||
| self.assertEqual([0, 1, 2, 3], plain_flat) | |||||
| @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||||
| {"mapping_type": _CustomMapping}) | |||||
| def testPackDictOrder(self, mapping_type): | |||||
| """Packing orders dicts by key, including OrderedDicts.""" | |||||
| custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) | |||||
| plain = {"d": 0, "b": 0, "a": 0, "c": 0} | |||||
| seq = [0, 1, 2, 3] | |||||
| custom_reconstruction = nest.pack_sequence_as(custom, seq) | |||||
| plain_reconstruction = nest.pack_sequence_as(plain, seq) | |||||
| self.assertIsInstance(custom_reconstruction, mapping_type) | |||||
| self.assertIsInstance(plain_reconstruction, dict) | |||||
| self.assertEqual( | |||||
| mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), | |||||
| custom_reconstruction) | |||||
| self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) | |||||
| Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testFlattenAndPack_withDicts(self): | |||||
| # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. | |||||
| mess = [ | |||||
| "z", | |||||
| NestTest.Abc(3, 4), { | |||||
| "d": _CustomMapping({ | |||||
| 41: 4 | |||||
| }), | |||||
| "c": [ | |||||
| 1, | |||||
| collections.OrderedDict([ | |||||
| ("b", 3), | |||||
| ("a", 2), | |||||
| ]), | |||||
| ], | |||||
| "b": 5 | |||||
| }, 17 | |||||
| ] | |||||
| flattened = nest.flatten(mess) | |||||
| self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) | |||||
| structure_of_mess = [ | |||||
| 14, | |||||
| NestTest.Abc("a", True), | |||||
| { | |||||
| "d": _CustomMapping({ | |||||
| 41: 42 | |||||
| }), | |||||
| "c": [ | |||||
| 0, | |||||
| collections.OrderedDict([ | |||||
| ("b", 9), | |||||
| ("a", 8), | |||||
| ]), | |||||
| ], | |||||
| "b": 3 | |||||
| }, | |||||
| "hi everybody", | |||||
| ] | |||||
| unflattened = nest.pack_sequence_as(structure_of_mess, flattened) | |||||
| self.assertEqual(unflattened, mess) | |||||
| # Check also that the OrderedDict was created, with the correct key order. | |||||
| unflattened_ordered_dict = unflattened[2]["c"][1] | |||||
| self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) | |||||
| self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) | |||||
| unflattened_custom_mapping = unflattened[2]["d"] | |||||
| self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) | |||||
| self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) | |||||
| def testFlatten_numpyIsNotFlattened(self): | |||||
| structure = np.array([1, 2, 3]) | |||||
| flattened = nest.flatten(structure) | |||||
| self.assertEqual(len(flattened), 1) | |||||
| def testFlatten_stringIsNotFlattened(self): | |||||
| structure = "lots of letters" | |||||
| flattened = nest.flatten(structure) | |||||
| self.assertEqual(len(flattened), 1) | |||||
| unflattened = nest.pack_sequence_as("goodbye", flattened) | |||||
| self.assertEqual(structure, unflattened) | |||||
| def testPackSequenceAs_notIterableError(self): | |||||
| with self.assertRaisesRegexp(TypeError, | |||||
| "flat_sequence must be a sequence"): | |||||
| nest.pack_sequence_as("hi", "bye") | |||||
| def testPackSequenceAs_wrongLengthsError(self): | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| "Structure had 2 elements, but flat_sequence had 3 elements."): | |||||
| nest.pack_sequence_as(["hello", "world"], | |||||
| ["and", "goodbye", "again"]) | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testIsSequence(self): | |||||
| self.assertFalse(nest.is_sequence("1234")) | |||||
| self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) | |||||
| self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) | |||||
| self.assertTrue(nest.is_sequence([])) | |||||
| self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) | |||||
| self.assertFalse(nest.is_sequence(set([1, 2]))) | |||||
| ones = array_ops.ones([2, 3]) | |||||
| self.assertFalse(nest.is_sequence(ones)) | |||||
| self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) | |||||
| self.assertFalse(nest.is_sequence(np.ones((4, 5)))) | |||||
| @parameterized.parameters({"mapping_type": _CustomMapping}, | |||||
| {"mapping_type": dict}) | |||||
| def testFlattenDictItems(self, mapping_type): | |||||
| dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))}) | |||||
| flat = {4: "a", 5: "b", 6: "c", 8: "d"} | |||||
| self.assertEqual(nest.flatten_dict_items(dictionary), flat) | |||||
| with self.assertRaises(TypeError): | |||||
| nest.flatten_dict_items(4) | |||||
| bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))}) | |||||
| with self.assertRaisesRegexp(ValueError, "not unique"): | |||||
| nest.flatten_dict_items(bad_dictionary) | |||||
| another_bad_dictionary = mapping_type({ | |||||
| (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) | |||||
| }) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): | |||||
| nest.flatten_dict_items(another_bad_dictionary) | |||||
| # pylint does not correctly recognize these as class names and | |||||
| # suggests to use variable style under_score naming. | |||||
| # pylint: disable=invalid-name | |||||
| Named0ab = collections.namedtuple("named_0", ("a", "b")) | |||||
| Named1ab = collections.namedtuple("named_1", ("a", "b")) | |||||
| SameNameab = collections.namedtuple("same_name", ("a", "b")) | |||||
| SameNameab2 = collections.namedtuple("same_name", ("a", "b")) | |||||
| SameNamexy = collections.namedtuple("same_name", ("x", "y")) | |||||
| SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) | |||||
| SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) | |||||
| NotSameName = collections.namedtuple("not_same_name", ("a", "b")) | |||||
| # pylint: enable=invalid-name | |||||
| class SameNamedType1(SameNameab): | |||||
| pass | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testAssertSameStructure(self): | |||||
| structure1 = (((1, 2), 3), 4, (5, 6)) | |||||
| structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||||
| structure_different_num_elements = ("spam", "eggs") | |||||
| structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) | |||||
| nest.assert_same_structure(structure1, structure2) | |||||
| nest.assert_same_structure("abc", 1.0) | |||||
| nest.assert_same_structure("abc", np.array([0, 1])) | |||||
| nest.assert_same_structure("abc", constant_op.constant([0, 1])) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| ("The two structures don't have the same nested structure\\.\n\n" | |||||
| "First structure:.*?\n\n" | |||||
| "Second structure:.*\n\n" | |||||
| "More specifically: Substructure " | |||||
| r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' | |||||
| 'substructure "type=str str=spam" is not\n' | |||||
| "Entire first structure:\n" | |||||
| r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" | |||||
| "Entire second structure:\n" | |||||
| r"\(\., \.\)")): | |||||
| nest.assert_same_structure(structure1, structure_different_num_elements) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| ("The two structures don't have the same nested structure\\.\n\n" | |||||
| "First structure:.*?\n\n" | |||||
| "Second structure:.*\n\n" | |||||
| r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||||
| r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' | |||||
| "is not")): | |||||
| nest.assert_same_structure([0, 1], np.array([0, 1])) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| ("The two structures don't have the same nested structure\\.\n\n" | |||||
| "First structure:.*?\n\n" | |||||
| "Second structure:.*\n\n" | |||||
| r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||||
| 'is a sequence, while substructure "type=int str=0" ' | |||||
| "is not")): | |||||
| nest.assert_same_structure(0, [0, 1]) | |||||
| self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| ("don't have the same nested structure\\.\n\n" | |||||
| "First structure: .*?\n\nSecond structure: ")): | |||||
| nest.assert_same_structure(structure1, structure_different_nesting) | |||||
| self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), | |||||
| NestTest.Named0ab("a", "b")) | |||||
| nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||||
| NestTest.Named0ab("a", "b")) | |||||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| ("don't have the same nested structure\\.\n\n" | |||||
| "First structure: .*?\n\nSecond structure: ")): | |||||
| nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||||
| NestTest.Named0ab([3], 4)) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| ("don't have the same nested structure\\.\n\n" | |||||
| "First structure: .*?\n\nSecond structure: ")): | |||||
| nest.assert_same_structure([[3], 4], [3, [4]]) | |||||
| structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||||
| with self.assertRaisesRegexp(TypeError, | |||||
| "don't have the same sequence type"): | |||||
| nest.assert_same_structure(structure1, structure1_list) | |||||
| nest.assert_same_structure(structure1, structure2, check_types=False) | |||||
| nest.assert_same_structure(structure1, structure1_list, check_types=False) | |||||
| with self.assertRaisesRegexp(ValueError, | |||||
| "don't have the same set of keys"): | |||||
| nest.assert_same_structure({"a": 1}, {"b": 1}) | |||||
| nest.assert_same_structure(NestTest.SameNameab(0, 1), | |||||
| NestTest.SameNameab2(2, 3)) | |||||
| # This assertion is expected to pass: two namedtuples with the same | |||||
| # name and field names are considered to be identical. | |||||
| nest.assert_same_structure( | |||||
| NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), | |||||
| NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) | |||||
| expected_message = "The two structures don't have the same.*" | |||||
| with self.assertRaisesRegexp(ValueError, expected_message): | |||||
| nest.assert_same_structure( | |||||
| NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), | |||||
| NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) | |||||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) | |||||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) | |||||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) | |||||
| EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name | |||||
| def testHeterogeneousComparison(self): | |||||
| nest.assert_same_structure({"a": 4}, _CustomMapping(a=3)) | |||||
| nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testMapStructure(self): | |||||
| structure1 = (((1, 2), 3), 4, (5, 6)) | |||||
| structure2 = (((7, 8), 9), 10, (11, 12)) | |||||
| structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) | |||||
| nest.assert_same_structure(structure1, structure1_plus1) | |||||
| self.assertAllEqual( | |||||
| [2, 3, 4, 5, 6, 7], | |||||
| nest.flatten(structure1_plus1)) | |||||
| structure1_plus_structure2 = nest.map_structure( | |||||
| lambda x, y: x + y, structure1, structure2) | |||||
| self.assertEqual( | |||||
| (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), | |||||
| structure1_plus_structure2) | |||||
| self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) | |||||
| self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) | |||||
| # Empty structures | |||||
| self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) | |||||
| self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) | |||||
| self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) | |||||
| self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, | |||||
| NestTest.EmptyNT())) | |||||
| # This is checking actual equality of types, empty list != empty tuple | |||||
| self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) | |||||
| with self.assertRaisesRegexp(TypeError, "callable"): | |||||
| nest.map_structure("bad", structure1_plus1) | |||||
| with self.assertRaisesRegexp(ValueError, "at least one structure"): | |||||
| nest.map_structure(lambda x: x) | |||||
| with self.assertRaisesRegexp(ValueError, "same number of elements"): | |||||
| nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) | |||||
| with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
| nest.map_structure(lambda x, y: None, 3, (3,)) | |||||
| with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||||
| nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) | |||||
| with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
| nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) | |||||
| structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||||
| with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||||
| nest.map_structure(lambda x, y: None, structure1, structure1_list) | |||||
| nest.map_structure(lambda x, y: None, structure1, structure1_list, | |||||
| check_types=False) | |||||
| with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
| nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), | |||||
| check_types=False) | |||||
| with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||||
| nest.map_structure(lambda x: None, structure1, foo="a") | |||||
| with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||||
| nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") | |||||
| ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testMapStructureWithStrings(self): | |||||
| inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) | |||||
| inp_b = NestTest.ABTuple(a=2, b=(1, 3)) | |||||
| out = nest.map_structure(lambda string, repeats: string * repeats, | |||||
| inp_a, | |||||
| inp_b) | |||||
| self.assertEqual("foofoo", out.a) | |||||
| self.assertEqual("bar", out.b[0]) | |||||
| self.assertEqual("bazbazbaz", out.b[1]) | |||||
| nt = NestTest.ABTuple(a=("something", "something_else"), | |||||
| b="yet another thing") | |||||
| rev_nt = nest.map_structure(lambda x: x[::-1], nt) | |||||
| # Check the output is the correct structure, and all strings are reversed. | |||||
| nest.assert_same_structure(nt, rev_nt) | |||||
| self.assertEqual(nt.a[0][::-1], rev_nt.a[0]) | |||||
| self.assertEqual(nt.a[1][::-1], rev_nt.a[1]) | |||||
| self.assertEqual(nt.b[::-1], rev_nt.b) | |||||
| @test_util.run_deprecated_v1 | |||||
| def testMapStructureOverPlaceholders(self): | |||||
| inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||||
| array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||||
| inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||||
| array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||||
| output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) | |||||
| nest.assert_same_structure(output, inp_a) | |||||
| self.assertShapeEqual(np.zeros((3, 4)), output[0]) | |||||
| self.assertShapeEqual(np.zeros((3, 7)), output[1]) | |||||
| feed_dict = { | |||||
| inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), | |||||
| inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) | |||||
| } | |||||
| with self.cached_session() as sess: | |||||
| output_np = sess.run(output, feed_dict=feed_dict) | |||||
| self.assertAllClose(output_np[0], | |||||
| feed_dict[inp_a][0] + feed_dict[inp_b][0]) | |||||
| self.assertAllClose(output_np[1], | |||||
| feed_dict[inp_a][1] + feed_dict[inp_b][1]) | |||||
| def testAssertShallowStructure(self): | |||||
| inp_ab = ["a", "b"] | |||||
| inp_abc = ["a", "b", "c"] | |||||
| expected_message = ( | |||||
| "The two structures don't have the same sequence length. Input " | |||||
| "structure has length 2, while shallow structure has length 3.") | |||||
| with self.assertRaisesRegexp(ValueError, expected_message): | |||||
| nest.assert_shallow_structure(inp_abc, inp_ab) | |||||
| inp_ab1 = [(1, 1), (2, 2)] | |||||
| inp_ab2 = [[1, 1], [2, 2]] | |||||
| expected_message = ( | |||||
| "The two structures don't have the same sequence type. Input structure " | |||||
| "has type <(type|class) 'tuple'>, while shallow structure has type " | |||||
| "<(type|class) 'list'>.") | |||||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||||
| nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) | |||||
| inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} | |||||
| inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} | |||||
| expected_message = ( | |||||
| r"The two structures don't have the same keys. Input " | |||||
| r"structure has keys \['c'\], while shallow structure has " | |||||
| r"keys \['d'\].") | |||||
| with self.assertRaisesRegexp(ValueError, expected_message): | |||||
| nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||||
| inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) | |||||
| inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) | |||||
| nest.assert_shallow_structure(inp_ab, inp_ba) | |||||
| # This assertion is expected to pass: two namedtuples with the same | |||||
| # name and field names are considered to be identical. | |||||
| inp_shallow = NestTest.SameNameab(1, 2) | |||||
| inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) | |||||
| nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) | |||||
| nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) | |||||
| def testFlattenUpTo(self): | |||||
| # Shallow tree ends at scalar. | |||||
| input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||||
| shallow_tree = [[True, True], [False, True]] | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) | |||||
| self.assertEqual(flattened_shallow_tree, [True, True, False, True]) | |||||
| # Shallow tree ends at string. | |||||
| input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] | |||||
| shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] | |||||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| input_tree) | |||||
| input_tree_flattened = nest.flatten(input_tree) | |||||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) | |||||
| self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) | |||||
| # Make sure dicts are correctly flattened, yielding values, not keys. | |||||
| input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} | |||||
| shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} | |||||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| input_tree) | |||||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| [1, {"c": 2}, 3, (4, 5)]) | |||||
| # Namedtuples. | |||||
| ab_tuple = NestTest.ABTuple | |||||
| input_tree = ab_tuple(a=[0, 1], b=2) | |||||
| shallow_tree = ab_tuple(a=0, b=1) | |||||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| input_tree) | |||||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| [[0, 1], 2]) | |||||
| # Nested dicts, OrderedDicts and namedtuples. | |||||
| input_tree = collections.OrderedDict( | |||||
| [("a", ab_tuple(a=[0, {"b": 1}], b=2)), | |||||
| ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) | |||||
| shallow_tree = input_tree | |||||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| input_tree) | |||||
| self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) | |||||
| shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) | |||||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| input_tree) | |||||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| [ab_tuple(a=[0, {"b": 1}], b=2), | |||||
| 3, | |||||
| collections.OrderedDict([("f", 4)])]) | |||||
| shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) | |||||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| input_tree) | |||||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| [ab_tuple(a=[0, {"b": 1}], b=2), | |||||
| {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) | |||||
| ## Shallow non-list edge-case. | |||||
| # Using iterable elements. | |||||
| input_tree = ["input_tree"] | |||||
| shallow_tree = "shallow_tree" | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| input_tree = ["input_tree_0", "input_tree_1"] | |||||
| shallow_tree = "shallow_tree" | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| # Using non-iterable elements. | |||||
| input_tree = [0] | |||||
| shallow_tree = 9 | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| input_tree = [0, 1] | |||||
| shallow_tree = 9 | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| ## Both non-list edge-case. | |||||
| # Using iterable elements. | |||||
| input_tree = "input_tree" | |||||
| shallow_tree = "shallow_tree" | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| # Using non-iterable elements. | |||||
| input_tree = 0 | |||||
| shallow_tree = 0 | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| ## Input non-list edge-case. | |||||
| # Using iterable elements. | |||||
| input_tree = "input_tree" | |||||
| shallow_tree = ["shallow_tree"] | |||||
| expected_message = ("If shallow structure is a sequence, input must also " | |||||
| "be a sequence. Input has type: <(type|class) 'str'>.") | |||||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| input_tree = "input_tree" | |||||
| shallow_tree = ["shallow_tree_9", "shallow_tree_8"] | |||||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| # Using non-iterable elements. | |||||
| input_tree = 0 | |||||
| shallow_tree = [9] | |||||
| expected_message = ("If shallow structure is a sequence, input must also " | |||||
| "be a sequence. Input has type: <(type|class) 'int'>.") | |||||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| input_tree = 0 | |||||
| shallow_tree = [9, 8] | |||||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| def testMapStructureUpTo(self): | |||||
| # Named tuples. | |||||
| ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||||
| op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||||
| inp_val = ab_tuple(a=2, b=3) | |||||
| inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) | |||||
| out = nest.map_structure_up_to( | |||||
| inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) | |||||
| self.assertEqual(out.a, 6) | |||||
| self.assertEqual(out.b, 15) | |||||
| # Lists. | |||||
| data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||||
| name_list = ["evens", ["odds", "primes"]] | |||||
| out = nest.map_structure_up_to( | |||||
| name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), | |||||
| name_list, data_list) | |||||
| self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) | |||||
| # Dicts. | |||||
| inp_val = dict(a=2, b=3) | |||||
| inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) | |||||
| out = nest.map_structure_up_to( | |||||
| inp_val, | |||||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| self.assertEqual(out["a"], 6) | |||||
| self.assertEqual(out["b"], 15) | |||||
| # Non-equal dicts. | |||||
| inp_val = dict(a=2, b=3) | |||||
| inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) | |||||
| with self.assertRaisesRegexp(ValueError, "same keys"): | |||||
| nest.map_structure_up_to( | |||||
| inp_val, | |||||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| # Dict+custom mapping. | |||||
| inp_val = dict(a=2, b=3) | |||||
| inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) | |||||
| out = nest.map_structure_up_to( | |||||
| inp_val, | |||||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| self.assertEqual(out["a"], 6) | |||||
| self.assertEqual(out["b"], 15) | |||||
| # Non-equal dict/mapping. | |||||
| inp_val = dict(a=2, b=3) | |||||
| inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) | |||||
| with self.assertRaisesRegexp(ValueError, "same keys"): | |||||
| nest.map_structure_up_to( | |||||
| inp_val, | |||||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| def testGetTraverseShallowStructure(self): | |||||
| scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] | |||||
| scalar_traverse_r = nest.get_traverse_shallow_structure( | |||||
| lambda s: not isinstance(s, tuple), | |||||
| scalar_traverse_input) | |||||
| self.assertEqual(scalar_traverse_r, | |||||
| [True, True, False, [True, True], {"a": False}, []]) | |||||
| nest.assert_shallow_structure(scalar_traverse_r, | |||||
| scalar_traverse_input) | |||||
| structure_traverse_input = [(1, [2]), ([1], 2)] | |||||
| structure_traverse_r = nest.get_traverse_shallow_structure( | |||||
| lambda s: (True, False) if isinstance(s, tuple) else True, | |||||
| structure_traverse_input) | |||||
| self.assertEqual(structure_traverse_r, | |||||
| [(True, False), ([True], False)]) | |||||
| nest.assert_shallow_structure(structure_traverse_r, | |||||
| structure_traverse_input) | |||||
| with self.assertRaisesRegexp(TypeError, "returned structure"): | |||||
| nest.get_traverse_shallow_structure(lambda _: [True], 0) | |||||
| with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): | |||||
| nest.get_traverse_shallow_structure(lambda _: 1, [1]) | |||||
| with self.assertRaisesRegexp( | |||||
| TypeError, "didn't return a depth=1 structure of bools"): | |||||
| nest.get_traverse_shallow_structure(lambda _: [1], [1]) | |||||
| def testYieldFlatStringPaths(self): | |||||
| for inputs_expected in ({"inputs": [], "expected": []}, | |||||
| {"inputs": 3, "expected": [()]}, | |||||
| {"inputs": [3], "expected": [(0,)]}, | |||||
| {"inputs": {"a": 3}, "expected": [("a",)]}, | |||||
| {"inputs": {"a": {"b": 4}}, | |||||
| "expected": [("a", "b")]}, | |||||
| {"inputs": [{"a": 2}], "expected": [(0, "a")]}, | |||||
| {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, | |||||
| {"inputs": [{"a": [(23, 42)]}], | |||||
| "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, | |||||
| {"inputs": [{"a": ([23], 42)}], | |||||
| "expected": [(0, "a", 0, 0), (0, "a", 1)]}, | |||||
| {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, | |||||
| "expected": [("a", "a"), ("c", 0, 0, 0)]}, | |||||
| {"inputs": {"0": [{"1": 23}]}, | |||||
| "expected": [("0", 0, "1")]}): | |||||
| inputs = inputs_expected["inputs"] | |||||
| expected = inputs_expected["expected"] | |||||
| self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) | |||||
| def testFlattenWithStringPaths(self): | |||||
| for inputs_expected in ( | |||||
| {"inputs": [], "expected": []}, | |||||
| {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, | |||||
| {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): | |||||
| inputs = inputs_expected["inputs"] | |||||
| expected = inputs_expected["expected"] | |||||
| self.assertEqual( | |||||
| nest.flatten_with_joined_string_paths(inputs, separator="/"), | |||||
| expected) | |||||
| # Need a separate test for namedtuple as we can't declare tuple definitions | |||||
| # in the @parameterized arguments. | |||||
| def testFlattenNamedTuple(self): | |||||
| # pylint: disable=invalid-name | |||||
| Foo = collections.namedtuple("Foo", ["a", "b"]) | |||||
| Bar = collections.namedtuple("Bar", ["c", "d"]) | |||||
| # pylint: enable=invalid-name | |||||
| test_cases = [ | |||||
| (Foo(a=3, b=Bar(c=23, d=42)), | |||||
| [("a", 3), ("b/c", 23), ("b/d", 42)]), | |||||
| (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")), | |||||
| [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), | |||||
| (Bar(c=42, d=43), | |||||
| [("c", 42), ("d", 43)]), | |||||
| (Bar(c=[42], d=43), | |||||
| [("c/0", 42), ("d", 43)]), | |||||
| ] | |||||
| for inputs, expected in test_cases: | |||||
| self.assertEqual( | |||||
| list(nest.flatten_with_joined_string_paths(inputs)), expected) | |||||
| @parameterized.named_parameters( | |||||
| ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), | |||||
| ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, | |||||
| {"a": ("a", 4), "b": ("b", 6)}), | |||||
| ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), | |||||
| ("nested", | |||||
| {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, | |||||
| {"a": [("a/0", 10), ("a/1", 12)], | |||||
| "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) | |||||
| def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): | |||||
| def format_sum(path, *values): | |||||
| return (path, sum(values)) | |||||
| result = nest.map_structure_with_paths(format_sum, s1, s2, | |||||
| check_types=check_types) | |||||
| self.assertEqual(expected, result) | |||||
| @parameterized.named_parameters( | |||||
| ("tuples", (1, 2), (3, 4, 5), ValueError), | |||||
| ("dicts", {"a": 1}, {"b": 2}, ValueError), | |||||
| ("mixed", (1, 2), [3, 4], TypeError), | |||||
| ("nested", | |||||
| {"a": [2, 3], "b": [1, 3]}, | |||||
| {"b": [5, 6, 7], "a": [8, 9]}, | |||||
| ValueError | |||||
| )) | |||||
| def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): | |||||
| with self.assertRaises(error_type): | |||||
| nest.map_structure_with_paths(lambda path, *s: 0, s1, s2) | |||||
| class NestBenchmark(test.Benchmark): | |||||
| def run_and_report(self, s1, s2, name): | |||||
| burn_iter, test_iter = 100, 30000 | |||||
| for _ in xrange(burn_iter): | |||||
| nest.assert_same_structure(s1, s2) | |||||
| t0 = time.time() | |||||
| for _ in xrange(test_iter): | |||||
| nest.assert_same_structure(s1, s2) | |||||
| t1 = time.time() | |||||
| self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, | |||||
| name=name) | |||||
| def benchmark_assert_structure(self): | |||||
| s1 = (((1, 2), 3), 4, (5, 6)) | |||||
| s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||||
| self.run_and_report(s1, s2, "assert_same_structure_6_elem") | |||||
| s1 = (((1, 2), 3), 4, (5, 6)) * 10 | |||||
| s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 | |||||
| self.run_and_report(s1, s2, "assert_same_structure_60_elem") | |||||
| if __name__ == "__main__": | |||||
| test.main() | |||||
| @@ -1,249 +0,0 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System.Linq; | |||||
| using Tensorflow; | |||||
| using Tensorflow.UnitTest; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.ops_test | |||||
| { | |||||
| /// <summary> | |||||
| /// excerpt of tensorflow/python/framework/ops_test.py | |||||
| /// </summary> | |||||
| [TestClass] | |||||
| public class ControlDependenciesTest : GraphModeTestBase | |||||
| { | |||||
| [TestMethod] | |||||
| public void TestBasic() | |||||
| { | |||||
| var g = tf.Graph().as_default(); | |||||
| Tensor a = null, b = null, c = null, d = null, e = null; | |||||
| a = constant_op.constant(1.0); | |||||
| b = constant_op.constant(1.0); | |||||
| tf_with(g.control_dependencies(new[] { a }), x => | |||||
| { | |||||
| c = constant_op.constant(1.0); | |||||
| d = array_ops.identity(b); | |||||
| e = array_ops.identity(c); | |||||
| }); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op })); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(d.op.control_inputs, new[] { a.op })); | |||||
| // e should be dominated by c. | |||||
| Assert.AreEqual(0, e.op.control_inputs.Length); | |||||
| } | |||||
| [Ignore("How to port the ConvertibleObj?")] | |||||
| [TestMethod] | |||||
| public void TestBasicWithConversion() | |||||
| { | |||||
| var g = tf.Graph().as_default(); | |||||
| // Note: _apply_op can be replaced by g.create_op | |||||
| var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT }); | |||||
| // TODO: ConvertibleObj, see original source below | |||||
| /* | |||||
| def testBasicWithConversion(self): | |||||
| g = ops.Graph() | |||||
| a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||||
| class ConvertibleObj(object): | |||||
| def _as_graph_element(self): | |||||
| return a | |||||
| with g.control_dependencies([ConvertibleObj()]): | |||||
| c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||||
| self.assertEqual(c.op.control_inputs, [a.op]) | |||||
| */ | |||||
| } | |||||
| [TestMethod] | |||||
| public void TestNested() | |||||
| { | |||||
| var g = tf.Graph().as_default(); | |||||
| var a_1 = constant_op.constant(1.0); | |||||
| var a_2 = constant_op.constant(3.0); | |||||
| var a_3 = constant_op.constant(4.0); | |||||
| var a_4 = constant_op.constant(5.0); | |||||
| Tensor b_1 = null, b_2 = null; | |||||
| tf_with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl => | |||||
| { | |||||
| b_1 = constant_op.constant(6.0); | |||||
| }); | |||||
| tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 => | |||||
| { | |||||
| tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 => | |||||
| { | |||||
| tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 => | |||||
| { | |||||
| tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 => | |||||
| { | |||||
| b_2 = constant_op.constant(7.0); | |||||
| }); | |||||
| }); | |||||
| }); | |||||
| }); | |||||
| //var z=tf.add(a_1, tf.multiply(b_2, b_1)); | |||||
| //with(g.control_dependencies(new[] {z}), ctrl => | |||||
| //{ | |||||
| // var z1 = tf.add(a_3, tf.multiply(a_4, a_2)); | |||||
| //}); | |||||
| //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||||
| assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); | |||||
| assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); | |||||
| } | |||||
| [TestMethod] | |||||
| public void TestClear() | |||||
| { | |||||
| var g = tf.Graph().as_default(); | |||||
| var a_1 = constant_op.constant(1.0); | |||||
| var a_2 = constant_op.constant(3.0); | |||||
| var a_3 = constant_op.constant(4.0); | |||||
| var a_4 = constant_op.constant(5.0); | |||||
| Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null; | |||||
| tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 => | |||||
| { | |||||
| tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 => | |||||
| { | |||||
| tf_with(g.control_dependencies(null), ctrl3 => | |||||
| { | |||||
| tf_with(g.control_dependencies(new[] { a_3 }), ctrl4 => | |||||
| { | |||||
| tf_with(g.control_dependencies(new[] { a_4 }), ctrl5 => | |||||
| { | |||||
| // deps [a_3, a_4] | |||||
| b_3_4 = constant_op.constant(7.0); | |||||
| }); | |||||
| // deps = [a_3] | |||||
| b_3 = constant_op.constant(8.0); | |||||
| }); | |||||
| // deps back to None | |||||
| b_none = constant_op.constant(9.0); | |||||
| }); | |||||
| // deps back to [a_1, a_2] | |||||
| b_1_2 = constant_op.constant(10.0); | |||||
| }); | |||||
| // deps back to [a_1] | |||||
| b_1 = constant_op.constant(11.0); | |||||
| tf_with(g.control_dependencies(null), ctrl6 => | |||||
| { | |||||
| // deps are None again | |||||
| b_none2 = constant_op.constant(12.0); | |||||
| }); | |||||
| }); | |||||
| // Note assertItemsEqual(given, expected), expected and given parameters should be swapped below | |||||
| assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs); | |||||
| assertItemsEqual(new object[0], b_none.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs); | |||||
| assertItemsEqual(new object[0], b_none2.op.control_inputs); | |||||
| } | |||||
| [TestMethod] | |||||
| public void TestComplex() | |||||
| { | |||||
| var g = tf.Graph().as_default(); | |||||
| // Usage pattern: | |||||
| // * Nodes a_i are constants defined at the outermost scope, and are used | |||||
| // as control inputs for the ith nested scope. | |||||
| // * Nodes b_i are defined as Mul(a_3, a_4) at each scope. | |||||
| // * Nodes c_i are defined as Mul(a_1, b_1) at each scope. | |||||
| // * Nodes d_i are defined as Mul(b_i, c_i) at each scope. | |||||
| // * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. | |||||
| var a_1 = constant_op.constant(1.0); | |||||
| var a_2 = constant_op.constant(2.0); | |||||
| var a_3 = constant_op.constant(3.0); | |||||
| var a_4 = constant_op.constant(4.0); | |||||
| Operation b_1 = null, b_2 = null, b_3 = null, b_4 = null; | |||||
| Operation c_1 = null, c_2 = null, c_3 = null, c_4 = null; | |||||
| Operation d_1 = null, d_2 = null, d_3 = null, d_4 = null; | |||||
| Operation e_1 = null, e_2 = null, e_3 = null, e_4 = null; | |||||
| tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 => | |||||
| { | |||||
| b_1 = tf.multiply(a_3, a_4); | |||||
| c_1 = tf.multiply(a_1, b_1.output); | |||||
| d_1 = tf.multiply(b_1.output, c_1.output); | |||||
| e_1 = constant_op.constant(5.0); | |||||
| tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 => | |||||
| { | |||||
| b_2 = tf.multiply(a_3, a_4); | |||||
| c_2 = tf.multiply(a_1, b_1.output); | |||||
| d_2 = tf.multiply(b_2.output, c_2.output); | |||||
| e_2 = tf.multiply(e_1.output, e_1.output); | |||||
| tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 => | |||||
| { | |||||
| b_3 = tf.multiply(a_3, a_4); | |||||
| c_3 = tf.multiply(a_1, b_1.output); | |||||
| d_3 = tf.multiply(b_3.output, c_3.output); | |||||
| e_3 = tf.multiply(e_2.output, e_2.output); | |||||
| tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 => | |||||
| { | |||||
| b_4 = tf.multiply(a_3, a_4); | |||||
| c_4 = tf.multiply(a_1, b_1.output); | |||||
| d_4 = tf.multiply(b_4.output, c_4.output); | |||||
| e_4 = tf.multiply(e_3.output, e_3.output); | |||||
| }); | |||||
| }); | |||||
| }); | |||||
| }); | |||||
| // Note assertItemsEqual(given, expected), expected and given parameters should be swapped below | |||||
| assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_1.op, a_2.op }, b_2.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_1.op, a_2.op }, b_3.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_1.op, a_2.op }, b_4.op.control_inputs); | |||||
| assertItemsEqual(new object[0], c_1.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_2.op }, c_2.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_2.op, a_3.op }, c_3.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_2.op, a_3.op, a_4.op }, c_4.op.control_inputs); | |||||
| assertItemsEqual(new object[0], d_1.op.control_inputs); | |||||
| assertItemsEqual(new object[0], d_2.op.control_inputs); | |||||
| assertItemsEqual(new object[0], d_3.op.control_inputs); | |||||
| assertItemsEqual(new object[0], d_4.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_1.op }, e_1.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_2.op }, e_2.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_3.op }, e_3.op.control_inputs); | |||||
| assertItemsEqual(new[] { a_4.op }, e_4.op.control_inputs); | |||||
| } | |||||
| [Ignore("Don't know how to create an operation with two outputs")] | |||||
| [TestMethod] | |||||
| public void TestRepeatedDependency() | |||||
| { | |||||
| /* | |||||
| def testRepeatedDependency(self): | |||||
| g = ops.Graph() | |||||
| a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) | |||||
| a_0, a_1 = a.outputs | |||||
| with g.control_dependencies([a_0]): | |||||
| b = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||||
| with g.control_dependencies([a_1]): | |||||
| c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||||
| self.assertEqual(b.op.control_inputs, [a]) | |||||
| self.assertEqual(c.op.control_inputs, [a]) | |||||
| */ | |||||
| } | |||||
| [TestMethod] | |||||
| public void TestNoControlDependencyWithDataDependency() | |||||
| { | |||||
| var g = tf.Graph().as_default(); | |||||
| Operation b = null; | |||||
| var a = constant_op.constant(100.0); | |||||
| tf_with(g.control_dependencies(new[] { a }), ctrl1 => | |||||
| { | |||||
| b = array_ops.identity(a); | |||||
| }); | |||||
| Assert.AreEqual(0, b.op.control_inputs.Length); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,222 +0,0 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Linq; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Operations; | |||||
| using Tensorflow.UnitTest; | |||||
| using static Tensorflow.Binding; | |||||
| namespace TensorFlowNET.UnitTest.ops_test | |||||
| { | |||||
| /// <summary> | |||||
| /// excerpt of tensorflow/python/framework/ops_test.py | |||||
| /// # These cases test the private Graph._create_op_from_tf_operation | |||||
| /// # method. Arguably we should only test the public APIs that depend on this | |||||
| /// # method. However, this logic is complex and tricky, and it can be difficult to | |||||
| /// # ascertain if we have adequate coverage (e.g. a graph may run successfully if | |||||
| /// # the control flow context isn't set properly, but a more complicated use case | |||||
| /// # that might not be obvious to test will fail). Thus we instead explicitly test | |||||
| /// # the low-level behavior. | |||||
| /// </summary> | |||||
| [Ignore] | |||||
| [TestClass] | |||||
| public class CreateOpFromTfOperationTest : GraphModeTestBase | |||||
| { | |||||
| [TestMethod] | |||||
| public void TestShape() | |||||
| { | |||||
| using (var g = tf.Graph().as_default()) | |||||
| { | |||||
| var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); | |||||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); | |||||
| var op = g._create_op_from_tf_operation(c_op); | |||||
| Assert.AreEqual("myop", op.name); | |||||
| Assert.AreEqual("Identity", op.type); | |||||
| Assert.AreEqual(1, len(op.outputs)); | |||||
| assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void TestUniqueName() | |||||
| { | |||||
| var graph = tf.Graph().as_default(); | |||||
| //var (c_op,op_desc) = ops._create_c_op(g, ops._NodeDef("Const", "myop"), new Tensor[0], new Operation[0]); | |||||
| //var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]); | |||||
| //var op = g._create_op_from_tf_operation(c_op); | |||||
| //var op2 = g._create_op_from_tf_operation(c_op2); | |||||
| var op = constant_op.constant(0, name: "myop").op; | |||||
| var op2 = constant_op.constant(0, name: "myop_1").op; | |||||
| // Create ops with same names as op1 and op2. We expect the new names to be | |||||
| // uniquified. | |||||
| var op3 = constant_op.constant(0, name: "myop").op; | |||||
| var op4 = constant_op.constant(0, name: "myop_1").op; | |||||
| self.assertEqual(op.name, "myop"); | |||||
| self.assertEqual(op2.name, "myop_1"); | |||||
| self.assertEqual(op3.name, "myop_2"); | |||||
| self.assertEqual(op4.name, "myop_1_1"); | |||||
| } | |||||
| [Ignore("need tesnroflow expose UpdateEdge API")] | |||||
| [TestMethod] | |||||
| public void TestCond() | |||||
| { | |||||
| var g = tf.Graph().as_default(); | |||||
| var x = constant_op.constant(10); | |||||
| var true_fn = new Func<Tensor>(() => | |||||
| { | |||||
| var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); | |||||
| var new_ops = g._add_new_tf_operations(); | |||||
| self.assertEqual(len(new_ops), 1); | |||||
| return x; | |||||
| }); | |||||
| control_flow_ops.cond(x < 10, true_fn, () => x); | |||||
| var op = g.get_operation_by_name("cond/myop"); | |||||
| //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta.txt", as_text:true); | |||||
| //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||||
| self.assertIsNotNone(op); | |||||
| self.assertEqual(op.name, "cond/myop"); | |||||
| self.assertEqual(op.type, "Identity"); | |||||
| //self.assertEqual(op.outputs, new object[0]); | |||||
| var op_input = op.inputs[0].op; | |||||
| self.assertEqual(op_input.type, "Switch"); | |||||
| self.assertEqual(op_input.inputs[0].name, x.name); | |||||
| self.assertEqual(op.graph, g); | |||||
| self.assertIsNotNone(op._get_control_flow_context()); | |||||
| var cond_text = op._get_control_flow_context() as ControlFlowContext; | |||||
| self.assertEqual(cond_text.Name, "cond/cond_text"); | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void TestWhileLoop() | |||||
| { | |||||
| var graph = tf.Graph().as_default(); | |||||
| Operation x = null; | |||||
| x = constant_op.constant(42); | |||||
| var body = new Func<int, int>(i => | |||||
| { | |||||
| ops._create_c_op(ops.get_default_graph(), ops._NodeDef("Identity", "myloop/myop"), new[] { x.output }, | |||||
| new Operation[0]); | |||||
| var new_ops = graph._add_new_tf_operations(); | |||||
| self.assertEqual(len(new_ops), 1); | |||||
| return i; | |||||
| }); | |||||
| // TODO: port control_flow_ops.while_loop | |||||
| //control_flow_ops.while_loop( i => i < 10, body, new int[]{0}, name = "myloop"); | |||||
| var op = graph.get_operation_by_name("myloop/myop"); | |||||
| self.assertIsNotNone(op); | |||||
| self.assertEqual(op.name, "myloop/myop"); | |||||
| self.assertEqual(op.type, "Identity"); | |||||
| self.assertEqual(op.outputs.Length, 0); | |||||
| var op_input = op.inputs[0].op; | |||||
| self.assertEqual(op_input.type, "Enter"); | |||||
| self.assertItemsEqual(op_input.inputs.OfType<Operation>().ToArray(), new[] { x }); | |||||
| self.assertEqual(op.graph, graph); | |||||
| self.assertIsNotNone(op._get_control_flow_context()); | |||||
| self.assertEqual(((ControlFlowContext)op._get_control_flow_context()).Name, "myloop/while_context"); | |||||
| /* | |||||
| @test_util.run_v1_only("b/120545219") | |||||
| def testWhileLoop(self): | |||||
| g = ops.Graph() | |||||
| with g.as_default(): | |||||
| x = test_ops.int_output() | |||||
| def body(i): | |||||
| ops._create_c_op(ops.get_default_graph(), | |||||
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||||
| new_ops = g._add_new_tf_operations() | |||||
| self.assertEqual(len(new_ops), 1) | |||||
| return i | |||||
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||||
| op = g.get_operation_by_name("myloop/myop") | |||||
| self.assertIsNotNone(op) | |||||
| self.assertEqual(op.name, "myloop/myop") | |||||
| self.assertEqual(op.type, "IntInput") | |||||
| self.assertEqual(op.outputs, []) | |||||
| op_input = op.inputs[0].op | |||||
| self.assertEqual(op_input.type, "Enter") | |||||
| self.assertEqual(list(op_input.inputs), [x]) | |||||
| self.assertEqual(op.graph, g) | |||||
| # pylint: disable=protected-access | |||||
| self.assertIsNotNone(op._get_control_flow_context()) | |||||
| self.assertEqual(op._get_control_flow_context().name, | |||||
| "myloop/while_context") | |||||
| # pylint: enable=protected-access | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void TestWhileLoopWithInternalControlDep() | |||||
| { | |||||
| /* | |||||
| @test_util.run_v1_only("b/120545219") | |||||
| def testWhileLoopWithInternalControlDep(self): | |||||
| g = ops.Graph() | |||||
| with g.as_default(): | |||||
| x = test_ops.int_output() | |||||
| def body(i): | |||||
| c = constant_op.constant(1.0, name="c") | |||||
| ops._create_c_op(ops.get_default_graph(), | |||||
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||||
| with ops.control_dependencies([c]): | |||||
| new_ops = g._add_new_tf_operations() | |||||
| self.assertEqual(len(new_ops), 1) | |||||
| return i | |||||
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||||
| op = g.get_operation_by_name("myloop/myop") | |||||
| self.assertIsNotNone(op) | |||||
| c = g.get_operation_by_name("myloop/c") | |||||
| self.assertIsNotNone(c) | |||||
| # Internal control dep is preserved | |||||
| self.assertEqual(op.control_inputs, [c]) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void TestWhileLoopWithExternalControlDep() | |||||
| { | |||||
| /* | |||||
| @test_util.run_v1_only("b/120545219") | |||||
| def testWhileLoopWithExternalControlDep(self): | |||||
| g = ops.Graph() | |||||
| with g.as_default(): | |||||
| x = test_ops.int_output() | |||||
| c = constant_op.constant(1.0) | |||||
| def body(i): | |||||
| ops._create_c_op(ops.get_default_graph(), | |||||
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||||
| with ops.control_dependencies([c]): | |||||
| new_ops = g._add_new_tf_operations() | |||||
| self.assertEqual(len(new_ops), 1) | |||||
| return i | |||||
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||||
| op = g.get_operation_by_name("myloop/myop") | |||||
| self.assertIsNotNone(op) | |||||
| # External control dep is removed and replaced with internal control dep | |||||
| self.assertNotEqual(op.control_inputs[0], c.op) | |||||
| self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) | |||||
| */ | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,196 +0,0 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Tensorflow; | |||||
| using Tensorflow.UnitTest; | |||||
| namespace TensorFlowNET.UnitTest.ops_test | |||||
| { | |||||
| /// <summary> | |||||
| /// excerpt of tensorflow/python/framework/ops_test.py | |||||
| /// </summary> | |||||
| [TestClass] | |||||
| public class GraphTest : GraphModeTestBase | |||||
| { | |||||
| [TestInitialize] | |||||
| public void SetUp() | |||||
| { | |||||
| ops.reset_default_graph(); | |||||
| } | |||||
| [TestCleanup] | |||||
| public void TearDown() | |||||
| { | |||||
| ops.reset_default_graph(); | |||||
| } | |||||
| private void _AssertDefault(Graph expected) | |||||
| { | |||||
| Assert.AreSame(ops.get_default_graph(), expected); | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testResetDefaultGraphNesting() | |||||
| { | |||||
| /* | |||||
| def testResetDefaultGraphNesting(self): | |||||
| g0 = ops.Graph() | |||||
| with self.assertRaises(AssertionError): | |||||
| with g0.as_default(): | |||||
| ops.reset_default_graph() | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testGraphContextManagerCancelsEager() | |||||
| { | |||||
| /* | |||||
| def testGraphContextManagerCancelsEager(self): | |||||
| with context.eager_mode(): | |||||
| with ops.Graph().as_default(): | |||||
| self.assertFalse(context.executing_eagerly()) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testGraphContextManager() | |||||
| { | |||||
| /* | |||||
| def testGraphContextManager(self): | |||||
| g0 = ops.Graph() | |||||
| with g0.as_default() as g1: | |||||
| self.assertIs(g0, g1) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testDefaultGraph() | |||||
| { | |||||
| /* | |||||
| def testDefaultGraph(self): | |||||
| orig = ops.get_default_graph() | |||||
| self._AssertDefault(orig) | |||||
| g0 = ops.Graph() | |||||
| self._AssertDefault(orig) | |||||
| context_manager_0 = g0.as_default() | |||||
| self._AssertDefault(orig) | |||||
| with context_manager_0 as g0: | |||||
| self._AssertDefault(g0) | |||||
| with ops.Graph().as_default() as g1: | |||||
| self._AssertDefault(g1) | |||||
| self._AssertDefault(g0) | |||||
| self._AssertDefault(orig) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testPreventFeeding() | |||||
| { | |||||
| /* | |||||
| def testPreventFeeding(self): | |||||
| g = ops.Graph() | |||||
| a = constant_op.constant(2.0) | |||||
| self.assertTrue(g.is_feedable(a)) | |||||
| g.prevent_feeding(a) | |||||
| self.assertFalse(g.is_feedable(a)) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testAsGraphElementConversions() | |||||
| { | |||||
| /* | |||||
| def testAsGraphElementConversions(self): | |||||
| class ConvertibleObj(object): | |||||
| def _as_graph_element(self): | |||||
| return "FloatOutput:0" | |||||
| class NonConvertibleObj(object): | |||||
| pass | |||||
| g = ops.Graph() | |||||
| a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||||
| self.assertEqual(a, g.as_graph_element(ConvertibleObj())) | |||||
| with self.assertRaises(TypeError): | |||||
| g.as_graph_element(NonConvertibleObj()) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testGarbageCollected() | |||||
| { | |||||
| /* | |||||
| # Regression test against creating custom __del__ functions in classes | |||||
| # involved in cyclic references, e.g. Graph and Operation. (Python won't gc | |||||
| # cycles that require calling a __del__ method, because the __del__ method can | |||||
| # theoretically increase the object's refcount to "save" it from gc, and any | |||||
| # already-deleted objects in the cycle would have be to restored.) | |||||
| def testGarbageCollected(self): | |||||
| # Create a graph we can delete and a weak reference to monitor if it's gc'd | |||||
| g = ops.Graph() | |||||
| g_ref = weakref.ref(g) | |||||
| # Create some ops | |||||
| with g.as_default(): | |||||
| a = constant_op.constant(2.0) | |||||
| b = constant_op.constant(3.0) | |||||
| c = math_ops.add(a, b) | |||||
| # Create a session we can delete | |||||
| with session.Session(graph=g) as sess: | |||||
| self.evaluate(c) | |||||
| # Delete all references and trigger gc | |||||
| del g | |||||
| del a | |||||
| del b | |||||
| del c | |||||
| del sess | |||||
| gc.collect() | |||||
| self.assertIsNone(g_ref()) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testRunnableAfterInvalidShape() | |||||
| { | |||||
| /* | |||||
| def testRunnableAfterInvalidShape(self): | |||||
| with ops.Graph().as_default(): | |||||
| with self.assertRaises(ValueError): | |||||
| math_ops.add([1, 2], [1, 2, 3]) | |||||
| a = constant_op.constant(1) | |||||
| with session.Session() as sess: | |||||
| self.evaluate(a) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testRunnableAfterInvalidShapeWithKernelLabelMap() | |||||
| { | |||||
| /* | |||||
| def testRunnableAfterInvalidShapeWithKernelLabelMap(self): | |||||
| g = ops.Graph() | |||||
| with g.as_default(): | |||||
| with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): | |||||
| with self.assertRaises(ValueError): | |||||
| test_ops.kernel_label_required(1) | |||||
| a = constant_op.constant(1) | |||||
| with session.Session() as sess: | |||||
| self.evaluate(a) | |||||
| */ | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,26 +0,0 @@ | |||||
| | |||||
| import tensorflow as tf | |||||
| # Create some variables. | |||||
| v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer) | |||||
| v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer) | |||||
| inc_v1 = v1.assign(v1+1) | |||||
| dec_v2 = v2.assign(v2-1) | |||||
| # Add an op to initialize the variables. | |||||
| init_op = tf.global_variables_initializer() | |||||
| # Add ops to save and restore all the variables. | |||||
| saver = tf.train.Saver() | |||||
| # Later, launch the model, initialize the variables, do some work, and save the | |||||
| # variables to disk. | |||||
| with tf.Session() as sess: | |||||
| sess.run(init_op) | |||||
| # Do some work with the model. | |||||
| inc_v1.op.run() | |||||
| dec_v2.op.run() | |||||
| # Save the variables to disk. | |||||
| save_path = saver.save(sess, "/tmp/model.ckpt") | |||||
| print("Model saved in path: %s" % save_path) | |||||