| @@ -17,6 +17,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Text", "src\Tens | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Recommenders", "src\TensorFlowNET.Recommenders\Tensorflow.Recommenders.csproj", "{F17AAECB-960A-4E18-A270-BAD776F0E55B}" | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Native.UnitTest", "test\TensorFlowNET.Native.UnitTest\Tensorflow.Native.UnitTest.csproj", "{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}" | |||
| EndProject | |||
| Global | |||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
| Debug|Any CPU = Debug|Any CPU | |||
| @@ -107,8 +109,8 @@ Global | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.Build.0 = Release|Any CPU | |||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|x64 | |||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.Build.0 = Debug|x64 | |||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.ActiveCfg = Debug|Any CPU | |||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.Build.0 = Debug|Any CPU | |||
| {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
| @@ -155,8 +157,8 @@ Global | |||
| {49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Release|x86.Build.0 = Release|Any CPU | |||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.ActiveCfg = Debug|x64 | |||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.Build.0 = Debug|x64 | |||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.ActiveCfg = Debug|Any CPU | |||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.Build.0 = Debug|Any CPU | |||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
| @@ -179,8 +181,8 @@ Global | |||
| {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x86.Build.0 = Release|Any CPU | |||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.ActiveCfg = Debug|x64 | |||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.Build.0 = Debug|x64 | |||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.ActiveCfg = Debug|Any CPU | |||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.Build.0 = Debug|Any CPU | |||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
| @@ -201,6 +203,30 @@ Global | |||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x64.Build.0 = Release|Any CPU | |||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.ActiveCfg = Release|Any CPU | |||
| {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.Build.0 = Release|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x64.ActiveCfg = Debug|x64 | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x64.Build.0 = Debug|x64 | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.ActiveCfg = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.Build.0 = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x86.ActiveCfg = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x86.Build.0 = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|Any CPU.Build.0 = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x64.ActiveCfg = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x64.Build.0 = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x86.ActiveCfg = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x86.Build.0 = Debug|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.ActiveCfg = Release|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.Build.0 = Release|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.ActiveCfg = Release|Any CPU | |||
| {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.Build.0 = Release|Any CPU | |||
| EndGlobalSection | |||
| GlobalSection(SolutionProperties) = preSolution | |||
| HideSolutionNode = FALSE | |||
| @@ -5,6 +5,7 @@ | |||
| <TargetFramework>netcoreapp3.1</TargetFramework> | |||
| <RootNamespace>Tensorflow</RootNamespace> | |||
| <AssemblyName>Tensorflow</AssemblyName> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| @@ -28,17 +28,10 @@ namespace Tensorflow | |||
| => ops.reset_default_graph(); | |||
| public Graph get_default_graph() | |||
| { | |||
| return ops.get_default_graph(); | |||
| } | |||
| => ops.get_default_graph(); | |||
| /// <summary> | |||
| /// Equivalent to <see cref="get_default_graph"/> but does not create a new graph if it there is none. | |||
| /// </summary> | |||
| public Graph peak_default_graph() | |||
| { | |||
| return ops.default_graph_stack.peak_controller(); | |||
| } | |||
| => ops.peak_default_graph(); | |||
| /// <summary> | |||
| /// Creates a new graph. | |||
| @@ -37,19 +37,7 @@ namespace Tensorflow.Contexts | |||
| if (shouldRunInEager) | |||
| return eagerAction(); | |||
| else | |||
| { | |||
| if (executing_eagerly()) | |||
| { | |||
| graph_mode(); | |||
| var result = graphAction(); | |||
| restore_mode(); | |||
| return result; | |||
| } | |||
| else | |||
| { | |||
| return graphAction(); | |||
| } | |||
| } | |||
| return graphAction(); | |||
| } | |||
| // [DebuggerStepThrough] | |||
| @@ -80,9 +80,12 @@ namespace Tensorflow.Contexts | |||
| /// Checks whether the current thread has eager execution enabled. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| [DebuggerStepThrough] | |||
| // [DebuggerStepThrough] | |||
| public bool executing_eagerly() | |||
| { | |||
| if(context_switches.Count() == 0) | |||
| tf.enable_eager_execution(); | |||
| return context_switches.Current().EagerMode; | |||
| } | |||
| @@ -103,11 +106,16 @@ namespace Tensorflow.Contexts | |||
| public void restore_mode() | |||
| { | |||
| context_switches.Pop(); | |||
| tf.get_default_graph(); | |||
| } | |||
| public void reset_context() | |||
| { | |||
| c_api.TFE_ContextClearCaches(_handle); | |||
| ops.reset_uid(); | |||
| ops.reset_default_graph(); | |||
| context_switches.Clear(); | |||
| if (_handle != null) | |||
| c_api.TFE_ContextClearCaches(_handle); | |||
| } | |||
| public void Dispose() | |||
| @@ -40,11 +40,21 @@ namespace Tensorflow.Contexts | |||
| }); | |||
| } | |||
| public void Clear() | |||
| { | |||
| stack.Clear(); | |||
| } | |||
| public void Pop() | |||
| { | |||
| stack.Pop(); | |||
| } | |||
| public int Count() | |||
| { | |||
| return stack.Count; | |||
| } | |||
| public ContextSwitch Current() | |||
| { | |||
| return stack.Peek(); | |||
| @@ -15,11 +15,13 @@ namespace Tensorflow | |||
| bool preserve_cardinality = false, | |||
| bool use_legacy_function = false) : base(input_dataset) | |||
| { | |||
| using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); | |||
| using var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}"); | |||
| func.Enter(); | |||
| var input = tf.placeholder(input_dataset.element_spec[0].dtype); | |||
| var output = map_func(input); | |||
| func.ToGraph(input, output); | |||
| func.Exit(); | |||
| structure = func.OutputStructure; | |||
| variant_tensor = ops.map_dataset(input_dataset.variant_tensor, | |||
| @@ -34,7 +34,6 @@ namespace Tensorflow.Functions | |||
| public ConcreteFunction(string name) | |||
| { | |||
| func_graph = new FuncGraph(name); | |||
| func_graph.as_default(); | |||
| } | |||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | |||
| @@ -46,7 +45,7 @@ namespace Tensorflow.Functions | |||
| public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | |||
| { | |||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
| string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||
| // IntPtr func_handle; | |||
| using var graph = new FuncGraph(func_name); | |||
| @@ -59,11 +58,12 @@ namespace Tensorflow.Functions | |||
| new[] { input }, | |||
| new[] { output }, | |||
| null); | |||
| graph.Exit(); | |||
| } | |||
| public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | |||
| { | |||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
| string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||
| // IntPtr func_handle; | |||
| using var graph = new FuncGraph(func_name); | |||
| @@ -79,12 +79,13 @@ namespace Tensorflow.Functions | |||
| new[] { input }, | |||
| new[] { output.variant_tensor }, | |||
| null); | |||
| graph.Exit(); | |||
| } | |||
| public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func, | |||
| TF_DataType[] dtypes, TensorShape[] shapes) | |||
| { | |||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
| string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||
| // IntPtr func_handle; | |||
| using var graph = new FuncGraph(func_name); | |||
| @@ -103,6 +104,7 @@ namespace Tensorflow.Functions | |||
| new[] { input1, input2, input3 }, | |||
| new[] { outputs.Item1, outputs.Item2 }, | |||
| null); | |||
| graph.Exit(); | |||
| } | |||
| public void ToGraph(Tensors inputs, Tensors outputs) | |||
| @@ -112,10 +114,19 @@ namespace Tensorflow.Functions | |||
| inputs, | |||
| outputs, | |||
| null); | |||
| OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray(); | |||
| } | |||
| public void Enter() | |||
| { | |||
| func_graph.as_default(); | |||
| } | |||
| public void Exit() | |||
| { | |||
| func_graph.Exit(); | |||
| } | |||
| public Tensors Invoke(Tensors inputs) | |||
| { | |||
| var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly()); | |||
| @@ -26,7 +26,6 @@ namespace Tensorflow.Functions | |||
| var output_names = new string[0]; | |||
| _func_graph = new FuncGraph(graph, name, attrs); | |||
| _func_graph.as_default(); | |||
| _func_graph.ToGraph(operations, inputs, outputs, output_names); | |||
| } | |||
| @@ -84,7 +84,7 @@ namespace Tensorflow.Functions | |||
| } | |||
| var gradients_wrt_outputs = new List<Tensor>(); | |||
| var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}"); | |||
| var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"); | |||
| backwards_graph.as_default(); | |||
| foreach (var output in trainable_outputs) | |||
| gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | |||
| @@ -101,6 +101,7 @@ namespace Tensorflow.Functions | |||
| if (!_func_graph.Outputs.Contains(capture)) | |||
| _func_graph.Outputs.Add(capture); | |||
| } | |||
| backwards_graph.Exit(); | |||
| var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | |||
| var backward_function_attr = new Dictionary<string, string>(); | |||
| @@ -8,7 +8,7 @@ namespace Tensorflow.Graphs | |||
| { | |||
| public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func) | |||
| { | |||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
| string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||
| // IntPtr func_handle; | |||
| using (var graph = new FuncGraph(func_name)) | |||
| @@ -22,6 +22,7 @@ namespace Tensorflow.Graphs | |||
| new[] { input }, | |||
| new[] { output }, | |||
| null); | |||
| graph.Exit(); | |||
| } | |||
| @@ -39,7 +40,7 @@ namespace Tensorflow.Graphs | |||
| public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func) | |||
| { | |||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
| string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; | |||
| // IntPtr func_handle; | |||
| using (var graph = new FuncGraph(func_name)) | |||
| @@ -54,6 +55,7 @@ namespace Tensorflow.Graphs | |||
| new[] { input1, input2 }, | |||
| new[] { output }, | |||
| null); | |||
| graph.Exit(); | |||
| } | |||
| return (Tensor a, Tensor b) => | |||
| @@ -18,7 +18,7 @@ namespace Tensorflow.Graphs | |||
| public override void OnEntry(MethodExecutionArgs args) | |||
| { | |||
| func_name = $"autograph_{args.Instance.GetHashCode()}.{args.Method.Name}"; | |||
| func_name = $"{args.Method.Name}_{Guid.NewGuid()}"; | |||
| if (functions.ContainsKey(func_name)) | |||
| { | |||
| @@ -34,6 +34,7 @@ namespace Tensorflow.Graphs | |||
| // make function as an Operation by autograph | |||
| // need to restore mode when exits | |||
| function = new ConcreteFunction(func_name); | |||
| function.Enter(); | |||
| // convert to Tensors | |||
| if (args.Arguments[0] is Tensors inputs) | |||
| @@ -68,6 +69,8 @@ namespace Tensorflow.Graphs | |||
| } | |||
| else | |||
| function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); | |||
| function.Exit(); | |||
| // cache function. | |||
| function.ReturnType = args.ReturnValue.GetType(); | |||
| @@ -25,63 +25,43 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public class DefaultGraphStack | |||
| { | |||
| private readonly List<StackModel> _stack = new List<StackModel>(); | |||
| private readonly Stack<Graph> _stack = new Stack<Graph>(); | |||
| Graph _global_default_graph; | |||
| public void set_controller(Graph @default) | |||
| public Graph get_default() | |||
| { | |||
| if (!_stack.Exists(x => x.Graph == @default)) | |||
| _stack.Add(new StackModel { Graph = @default, IsDefault = true }); | |||
| if (_stack.Count > 0) | |||
| return _stack.Peek(); | |||
| else if (_global_default_graph != null) | |||
| return _global_default_graph; | |||
| else | |||
| _global_default_graph = new Graph(); | |||
| foreach (var s in _stack) | |||
| s.IsDefault = s.Graph == @default; | |||
| return _global_default_graph; | |||
| } | |||
| public Graph get_controller() | |||
| public Graph get_controller(Graph g) | |||
| { | |||
| if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0) | |||
| _stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); | |||
| for (var i = _stack.Count - 1; i >= 0; i--) | |||
| { | |||
| var x = _stack[i]; | |||
| if (x.IsDefault) | |||
| return x.Graph; | |||
| } | |||
| throw new TensorflowException("Unable to find a default graph"); | |||
| _stack.Push(g); | |||
| return g; | |||
| } | |||
| public Graph peak_controller() | |||
| { | |||
| if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0) | |||
| if (_stack.Count == 0) | |||
| return null; | |||
| for (var i = _stack.Count - 1; i >= 0; i--) | |||
| { | |||
| var x = _stack[i]; | |||
| if (x.IsDefault) | |||
| return x.Graph; | |||
| } | |||
| return null; | |||
| return _stack.Peek(); | |||
| } | |||
| public bool remove(Graph g) | |||
| public void pop() | |||
| { | |||
| if (_stack.Count == 0) | |||
| return false; | |||
| var sm = _stack.Find(model => model.Graph == g); | |||
| return sm != null && _stack.Remove(sm); | |||
| _stack.Pop(); | |||
| } | |||
| public void reset() | |||
| { | |||
| _stack.Clear(); | |||
| } | |||
| private class StackModel | |||
| { | |||
| public Graph Graph { get; set; } | |||
| public bool IsDefault { get; set; } | |||
| _global_default_graph = null; | |||
| } | |||
| } | |||
| } | |||
| @@ -94,8 +94,6 @@ namespace Tensorflow.Graphs | |||
| // mark_as_return | |||
| Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray(); | |||
| tf.Context.restore_mode(); | |||
| return func_handle; | |||
| } | |||
| @@ -247,9 +245,10 @@ namespace Tensorflow.Graphs | |||
| return this; | |||
| } | |||
| protected override void DisposeManagedResources() | |||
| public override void Exit() | |||
| { | |||
| base.DisposeManagedResources(); | |||
| tf.Context.restore_mode(); | |||
| ops.pop_graph(); | |||
| } | |||
| } | |||
| } | |||
| @@ -146,6 +146,7 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// Returns a context manager that makes this `Graph` the default graph. | |||
| /// Must call Exit() to pop graph | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public virtual Graph as_default() | |||
| @@ -487,7 +488,7 @@ namespace Tensorflow | |||
| protected override void DisposeManagedResources() | |||
| { | |||
| ops.default_graph_stack.remove(this); | |||
| } | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| @@ -529,6 +530,11 @@ namespace Tensorflow | |||
| return new TensorShape(dims.Select(x => (int)x).ToArray()); | |||
| } | |||
| public virtual void Exit() | |||
| { | |||
| ops.pop_graph(); | |||
| } | |||
| string debugString = string.Empty; | |||
| public override string ToString() | |||
| { | |||
| @@ -95,7 +95,7 @@ namespace Tensorflow.Graphs | |||
| _copy_non_source(op, graph, op_map, base_graph); | |||
| } | |||
| tf.Context.restore_mode(); | |||
| graph.Exit(); | |||
| return op_map; | |||
| } | |||
| @@ -30,7 +30,7 @@ namespace Tensorflow | |||
| public Operation _apply_op_helper(string op_type_name, string name = null, Dictionary<string, object> keywords = null) | |||
| { | |||
| var g = ops.get_default_graph(); | |||
| var g = ops._get_graph_from_inputs(keywords == null ? new object[0] : keywords.Values.ToArray()); | |||
| var op_def = g.GetOpDef(op_type_name); | |||
| // Default name if not specified. | |||
| @@ -59,7 +59,8 @@ namespace Tensorflow | |||
| var input_types = new List<TF_DataType>(); | |||
| object values = null; | |||
| return tf_with(ops.name_scope(name), scope => | |||
| g.as_default(); | |||
| var ret_op = tf_with(ops.name_scope(name), scope => | |||
| { | |||
| var inferred_from = new Dictionary<string, object>(); | |||
| var base_types = new List<TF_DataType>(); | |||
| @@ -249,6 +250,8 @@ namespace Tensorflow | |||
| return op; | |||
| }); | |||
| g.Exit(); | |||
| return ret_op; | |||
| } | |||
| private void _MaybeColocateWith(ITensorOrOperation[] inputs) | |||
| @@ -78,6 +78,21 @@ namespace Tensorflow | |||
| return get_default_graph().get_collection_ref<T>(key); | |||
| } | |||
| public static Graph _get_graph_from_inputs(params object[] op_input_list) | |||
| { | |||
| var current_default_graph = get_default_graph(); | |||
| if (current_default_graph.building_function) | |||
| return current_default_graph; | |||
| Graph graph = null; | |||
| foreach (var op_input in op_input_list) | |||
| { | |||
| if (op_input is Tensor op_input_tensor) | |||
| graph = graph ?? op_input_tensor.graph; | |||
| } | |||
| return graph ?? current_default_graph; | |||
| } | |||
| public static Graph _get_graph_from_inputs(Tensors op_input_list) | |||
| => _get_graph_from_inputs(op_input_list: op_input_list, graph: null); | |||
| @@ -337,6 +352,11 @@ namespace Tensorflow | |||
| return Interlocked.Increment(ref uid_number); | |||
| } | |||
| public static void reset_uid() | |||
| { | |||
| uid_number = -1; | |||
| } | |||
| public static void colocate_with(bool ignore_existing = false) | |||
| { | |||
| _colocate_with_for_gradient(null, null, ignore_existing); | |||
| @@ -118,16 +118,10 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public static Graph get_default_graph() | |||
| { | |||
| //return _default_graph_stack.get_default() | |||
| return default_graph_stack.get_controller(); | |||
| } | |||
| => default_graph_stack.get_default(); | |||
| public static Graph set_default_graph(Graph graph) | |||
| { | |||
| default_graph_stack.set_controller(graph); | |||
| return default_graph_stack.get_controller(); | |||
| } | |||
| public static Graph set_default_graph(Graph g) | |||
| => default_graph_stack.get_controller(g); | |||
| /// <summary> | |||
| /// Clears the default graph stack and resets the global default graph. | |||
| @@ -147,5 +141,11 @@ namespace Tensorflow | |||
| // "exit the nesting and create a new graph."); | |||
| default_graph_stack.reset(); | |||
| } | |||
| public static Graph peak_default_graph() | |||
| => default_graph_stack.peak_controller(); | |||
| public static void pop_graph() | |||
| => default_graph_stack.pop(); | |||
| } | |||
| } | |||
| @@ -115,10 +115,10 @@ namespace Tensorflow.Keras | |||
| public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>(); | |||
| public void clear_session() | |||
| { | |||
| ops.reset_default_graph(); | |||
| tf.Context.reset_context(); | |||
| reset_uids(); | |||
| ops.set_default_session(tf.Session(ops.get_default_graph())); | |||
| var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase"); | |||
| // var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase"); | |||
| _GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>(); | |||
| _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0; | |||
| } | |||
| @@ -185,7 +185,7 @@ namespace Tensorflow.Keras | |||
| return tensor_util.constant_value(outputs); | |||
| var source_graph = outputs.graph; | |||
| using var exec_graph = _scratch_graph(); | |||
| var exec_graph = _scratch_graph(); | |||
| var global_graph = get_graph(); | |||
| if (source_graph == global_graph && exec_graph != global_graph) | |||
| { | |||
| @@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine | |||
| _set_mask_metadata(inputs, outputs, null); | |||
| }); | |||
| tf.Context.restore_mode(); | |||
| graph.Exit(); | |||
| return outputs; | |||
| } | |||
| @@ -81,8 +81,9 @@ namespace Tensorflow.Keras.Layers | |||
| sparse: args.Sparse, | |||
| ragged: args.Ragged); | |||
| graph.Exit(); | |||
| isPlaceholder = true; | |||
| tf.Context.restore_mode(); | |||
| } | |||
| // Create an input node to add to self.outbound_node | |||
| @@ -5,6 +5,7 @@ | |||
| <Version>0.0.1</Version> | |||
| <Description>TensorFlow Recommenders is a library for building recommender system models using TensorFlow.</Description> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| @@ -7,12 +7,17 @@ | |||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
| <Version>0.0.1</Version> | |||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
| <DefineConstants>DEBUG;TRACE</DefineConstants> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | |||
| <DefineConstants>DEBUG;TRACE</DefineConstants> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <None Include="..\..\LICENSE"> | |||
| <Pack>True</Pack> | |||
| @@ -1,8 +1,7 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using Tensorflow; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest | |||
| { | |||
| /// <summary> | |||
| /// tensorflow\c\c_api_test.cc | |||
| @@ -1,9 +1,8 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| using Tensorflow; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest | |||
| { | |||
| /// <summary> | |||
| /// tensorflow\c\c_api_test.cc | |||
| @@ -2,10 +2,9 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow; | |||
| using static TensorFlowNET.UnitTest.c_test_util; | |||
| using static Tensorflow.Native.UnitTest.c_test_util; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest | |||
| { | |||
| /// <summary> | |||
| /// tensorflow\c\c_api_function_test.cc | |||
| @@ -1,11 +1,9 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using NumSharp; | |||
| using System; | |||
| using Tensorflow; | |||
| using Tensorflow.Util; | |||
| using Buffer = Tensorflow.Buffer; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest | |||
| { | |||
| /// <summary> | |||
| /// tensorflow\c\c_api_test.cc | |||
| @@ -1,13 +1,11 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using Tensorflow; | |||
| using Tensorflow.Device; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.UnitTest; | |||
| namespace TensorFlowNET.UnitTest | |||
| namespace Tensorflow.Native.UnitTest | |||
| { | |||
| public class CApiTest : GraphModeTestBase | |||
| public class CApiTest | |||
| { | |||
| protected static readonly TF_Code TF_OK = TF_Code.TF_OK; | |||
| protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT; | |||
| @@ -1,10 +1,9 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow; | |||
| using Tensorflow.Util; | |||
| namespace TensorFlowNET.UnitTest | |||
| namespace Tensorflow.Native.UnitTest | |||
| { | |||
| /// <summary> | |||
| /// tensorflow\c\c_test_util.cc | |||
| @@ -1,9 +1,8 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using Tensorflow.Device; | |||
| using Tensorflow.Eager; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest.Eager | |||
| { | |||
| public partial class CApiEagerTest | |||
| { | |||
| @@ -1,10 +1,9 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using Tensorflow; | |||
| using Tensorflow.Eager; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest.Eager | |||
| { | |||
| public partial class CApiEagerTest | |||
| { | |||
| @@ -1,8 +1,7 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using Tensorflow.Eager; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest.Eager | |||
| { | |||
| public partial class CApiEagerTest | |||
| { | |||
| @@ -1,8 +1,7 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using Tensorflow.Eager; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest.Eager | |||
| { | |||
| public partial class CApiEagerTest | |||
| { | |||
| @@ -1,8 +1,7 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest.Eager | |||
| { | |||
| public partial class CApiEagerTest | |||
| { | |||
| @@ -1,8 +1,7 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using Tensorflow.Eager; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest.Eager | |||
| { | |||
| public partial class CApiEagerTest | |||
| { | |||
| @@ -1,9 +1,8 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using Tensorflow.Eager; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest.Eager | |||
| { | |||
| public partial class CApiEagerTest | |||
| { | |||
| @@ -1,10 +1,9 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using Tensorflow; | |||
| using Tensorflow.Eager; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest.Eager | |||
| { | |||
| /// <summary> | |||
| /// tensorflow\c\eager\c_api_test.cc | |||
| @@ -1,13 +1,12 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.Gradient | |||
| namespace Tensorflow.Native.UnitTest.Eager | |||
| { | |||
| [TestClass] | |||
| public class GradientEagerTest : PythonTest | |||
| public class GradientEagerTest | |||
| { | |||
| [TestMethod] | |||
| public void ConstantSquare() | |||
| @@ -1,8 +1,7 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class GraphBuildTest : CApiTest | |||
| @@ -1,10 +1,8 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| using Buffer = Tensorflow.Buffer; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace Tensorflow.Native.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class GraphTest : CApiTest | |||
| @@ -0,0 +1,74 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| namespace Tensorflow.Native.UnitTest.Sessions | |||
| { | |||
| [TestClass, Ignore] | |||
| public class SessionTest : CApiTest | |||
| { | |||
| /// <summary> | |||
| /// tensorflow\c\c_api_test.cc | |||
| /// `TEST(CAPI, Session)` | |||
| /// </summary> | |||
| [TestMethod] | |||
| public void Session() | |||
| { | |||
| using var s = new Status(); | |||
| using var graph = new Graph(); | |||
| // Make a placeholder operation. | |||
| var feed = c_test_util.Placeholder(graph, s); | |||
| // Make a constant operation with the scalar "2". | |||
| var two = c_test_util.ScalarConst(2, graph, s); | |||
| // Add operation. | |||
| var add = c_test_util.Add(feed, two, graph, s); | |||
| var csession = new CSession(graph, s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Run the graph. | |||
| var inputs = new Dictionary<Operation, Tensor>(); | |||
| inputs.Add(feed, new Tensor(3)); | |||
| csession.SetInputs(inputs); | |||
| var outputs = new TF_Output[] { new TF_Output(add, 0) }; | |||
| csession.SetOutputs(outputs); | |||
| csession.Run(s); | |||
| Tensor outTensor = csession.output_tensor(0); | |||
| EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | |||
| EXPECT_EQ(0, outTensor.NDims); | |||
| ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | |||
| var output_contents = outTensor.ToArray<int>(); | |||
| EXPECT_EQ(3 + 2, output_contents[0]); | |||
| // Add another operation to the graph. | |||
| var neg = c_test_util.Neg(add, graph, s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Run up to the new operation. | |||
| inputs = new Dictionary<Operation, Tensor>(); | |||
| inputs.Add(feed, new Tensor(7)); | |||
| csession.SetInputs(inputs); | |||
| outputs = new TF_Output[] { new TF_Output(neg, 0) }; | |||
| csession.SetOutputs(outputs); | |||
| csession.Run(s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| outTensor = csession.output_tensor(0); | |||
| ASSERT_TRUE(outTensor != IntPtr.Zero); | |||
| EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | |||
| EXPECT_EQ(0, outTensor.NDims); // scalar | |||
| ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | |||
| output_contents = outTensor.ToArray<int>(); | |||
| EXPECT_EQ(-(7 + 2), output_contents[0]); | |||
| // Clean up | |||
| csession.CloseAndDelete(s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,36 @@ | |||
| <Project Sdk="Microsoft.NET.Sdk"> | |||
| <PropertyGroup> | |||
| <TargetFramework>netcoreapp3.1</TargetFramework> | |||
| <IsPackable>false</IsPackable> | |||
| <Platforms>AnyCPU;x64</Platforms> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
| <DefineConstants>DEBUG;TRACE</DefineConstants> | |||
| <PlatformTarget>x64</PlatformTarget> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | |||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
| <DefineConstants>DEBUG;TRACE</DefineConstants> | |||
| <PlatformTarget>x64</PlatformTarget> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <PackageReference Include="FluentAssertions" Version="5.10.3" /> | |||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.8.3" /> | |||
| <PackageReference Include="MSTest.TestAdapter" Version="2.1.2" /> | |||
| <PackageReference Include="MSTest.TestFramework" Version="2.1.2" /> | |||
| <PackageReference Include="coverlet.collector" Version="1.3.0" /> | |||
| <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -0,0 +1,204 @@ | |||
| using FluentAssertions; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using NumSharp; | |||
| using System; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Native.UnitTest.Tensors | |||
| { | |||
| [TestClass] | |||
| public class TensorTest : CApiTest | |||
| { | |||
| [TestMethod] | |||
| public unsafe void TensorFromFixed() | |||
| { | |||
| var array = new float[1000]; | |||
| var span = new Span<float>(array, 100, 500); | |||
| fixed (float* ptr = &MemoryMarshal.GetReference(span)) | |||
| { | |||
| using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, 4 * span.Length)) | |||
| { | |||
| Assert.IsFalse(t.IsDisposed); | |||
| Assert.AreEqual(2000, (int)t.bytesize); | |||
| } | |||
| } | |||
| fixed (float* ptr = &array[0]) | |||
| { | |||
| using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length)) | |||
| { | |||
| Assert.IsFalse(t.IsDisposed); | |||
| Assert.AreEqual(4000, (int)t.bytesize); | |||
| } | |||
| } | |||
| } | |||
| [TestMethod] | |||
| public void TensorFromArray() | |||
| { | |||
| var array = new float[1000]; | |||
| using (var t = new Tensor(array, new long[] { array.Length }, tf.float32)) | |||
| { | |||
| Assert.IsFalse(t.IsDisposed); | |||
| Assert.AreEqual(1000 * sizeof(float), (int)t.bytesize); | |||
| } | |||
| using (var t = new Tensor(new float[] { 1 }, new long[] { 1 }, tf.float32)) | |||
| { | |||
| Assert.IsFalse(t.IsDisposed); | |||
| Assert.AreEqual(1 * sizeof(float), (int)t.bytesize); | |||
| } | |||
| using (var t = new Tensor(new float[] { 1 }, null, tf.float32)) | |||
| { | |||
| Assert.IsFalse(t.IsDisposed); | |||
| Assert.AreEqual(1 * sizeof(float), (int)t.bytesize); | |||
| t.shape.Should().BeEmpty(); | |||
| } | |||
| } | |||
| [TestMethod] | |||
| public void AllocateTensor() | |||
| { | |||
| ulong num_bytes = 6 * sizeof(float); | |||
| long[] dims = { 2, 3 }; | |||
| Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | |||
| EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); | |||
| EXPECT_EQ(2, t.NDims); | |||
| EXPECT_EQ((int)dims[0], t.shape[0]); | |||
| EXPECT_EQ(num_bytes, t.bytesize); | |||
| t.Dispose(); | |||
| } | |||
| /// <summary> | |||
| /// Port from c_api_test.cc | |||
| /// `TEST(CAPI, MaybeMove)` | |||
| /// </summary> | |||
| [TestMethod, Ignore] | |||
| public void MaybeMove() | |||
| { | |||
| NDArray nd = np.array(2, 3); | |||
| Tensor t = new Tensor(nd); | |||
| Tensor o = t.MaybeMove(); | |||
| ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own. | |||
| t.Dispose(); | |||
| } | |||
| /// <summary> | |||
| /// Port from c_api_test.cc | |||
| /// `TEST(CAPI, Tensor)` | |||
| /// </summary> | |||
| [TestMethod] | |||
| public void Tensor() | |||
| { | |||
| var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | |||
| var tensor = new Tensor(nd); | |||
| var array = tensor.ToArray<float>(); | |||
| EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | |||
| EXPECT_EQ(tensor.rank, nd.ndim); | |||
| EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); | |||
| EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); | |||
| EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] { 1, 2, 3, 4, 5, 6 })); | |||
| } | |||
| /// <summary> | |||
| /// Port from tensorflow\c\c_api_test.cc | |||
| /// `TEST(CAPI, SetShape)` | |||
| /// </summary> | |||
| [TestMethod] | |||
| public void SetShape() | |||
| { | |||
| var s = new Status(); | |||
| var graph = new Graph().as_default(); | |||
| var feed = c_test_util.Placeholder(graph, s); | |||
| var feed_out_0 = new TF_Output(feed, 0); | |||
| // Fetch the shape, it should be completely unknown. | |||
| int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| EXPECT_EQ(-1, num_dims); | |||
| // Set the shape to be unknown, expect no change. | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||
| EXPECT_EQ(-1, num_dims); | |||
| // Set the shape to be 2 x Unknown | |||
| long[] dims = { 2, -1 }; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||
| EXPECT_EQ(2, num_dims); | |||
| // Get the dimension vector appropriately. | |||
| var returned_dims = new long[dims.Length]; | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
| // Set to a new valid shape: [2, 3] | |||
| dims[1] = 3; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| // Fetch and see that the new value is returned. | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
| // Try to set 'unknown' with unknown rank on the shape and see that | |||
| // it doesn't change. | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| EXPECT_EQ(2, num_dims); | |||
| EXPECT_EQ(2, (int)returned_dims[0]); | |||
| EXPECT_EQ(3, (int)returned_dims[1]); | |||
| // Try to set 'unknown' with same rank on the shape and see that | |||
| // it doesn't change. | |||
| dims[0] = -1; | |||
| dims[1] = -1; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| EXPECT_EQ(2, num_dims); | |||
| EXPECT_EQ(2, (int)returned_dims[0]); | |||
| EXPECT_EQ(3, (int)returned_dims[1]); | |||
| // Try to fetch a shape with the wrong num_dims | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
| // Try to set an invalid shape (cannot change 2x3 to a 2x5). | |||
| dims[1] = 5; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
| // Test for a scalar. | |||
| var three = c_test_util.ScalarConst(3, graph, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| var three_out_0 = new TF_Output(three, 0); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| EXPECT_EQ(0, num_dims); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
| graph.Exit(); | |||
| s.Dispose(); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,10 +1,8 @@ | |||
| using System; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using Tensorflow; | |||
| using Tensorflow.Util; | |||
| using Buffer = Tensorflow.Buffer; | |||
| namespace TensorFlowNET.UnitTest | |||
| namespace Tensorflow.Native.UnitTest | |||
| { | |||
| /// <summary> | |||
| /// Port from `tensorflow\c\c_test_util.cc` | |||
| @@ -8,78 +8,11 @@ using Tensorflow; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class SessionTest : CApiTest | |||
| [TestClass, Ignore] | |||
| public class SessionTest | |||
| { | |||
| /// <summary> | |||
| /// tensorflow\c\c_api_test.cc | |||
| /// `TEST(CAPI, Session)` | |||
| /// </summary> | |||
| [TestMethod, Ignore] | |||
| public void Session() | |||
| { | |||
| lock (Locks.ProcessWide) | |||
| { | |||
| var s = new Status(); | |||
| var graph = new Graph().as_default(); | |||
| // Make a placeholder operation. | |||
| var feed = c_test_util.Placeholder(graph, s); | |||
| // Make a constant operation with the scalar "2". | |||
| var two = c_test_util.ScalarConst(2, graph, s); | |||
| // Add operation. | |||
| var add = c_test_util.Add(feed, two, graph, s); | |||
| var csession = new CSession(graph, s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Run the graph. | |||
| var inputs = new Dictionary<Operation, Tensor>(); | |||
| inputs.Add(feed, new Tensor(3)); | |||
| csession.SetInputs(inputs); | |||
| var outputs = new TF_Output[] { new TF_Output(add, 0) }; | |||
| csession.SetOutputs(outputs); | |||
| csession.Run(s); | |||
| Tensor outTensor = csession.output_tensor(0); | |||
| EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | |||
| EXPECT_EQ(0, outTensor.NDims); | |||
| ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | |||
| var output_contents = outTensor.ToArray<int>(); | |||
| EXPECT_EQ(3 + 2, output_contents[0]); | |||
| // Add another operation to the graph. | |||
| var neg = c_test_util.Neg(add, graph, s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Run up to the new operation. | |||
| inputs = new Dictionary<Operation, Tensor>(); | |||
| inputs.Add(feed, new Tensor(7)); | |||
| csession.SetInputs(inputs); | |||
| outputs = new TF_Output[] { new TF_Output(neg, 0) }; | |||
| csession.SetOutputs(outputs); | |||
| csession.Run(s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| outTensor = csession.output_tensor(0); | |||
| ASSERT_TRUE(outTensor != IntPtr.Zero); | |||
| EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | |||
| EXPECT_EQ(0, outTensor.NDims); // scalar | |||
| ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | |||
| output_contents = outTensor.ToArray<int>(); | |||
| EXPECT_EQ(-(7 + 2), output_contents[0]); | |||
| // Clean up | |||
| csession.CloseAndDelete(s); | |||
| ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
| } | |||
| } | |||
| [TestMethod] | |||
| public void EvalTensor() | |||
| { | |||
| @@ -7,201 +7,11 @@ using System.Runtime.InteropServices; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class TensorTest : CApiTest | |||
| [TestClass, Ignore] | |||
| public class TensorTest | |||
| { | |||
| [TestMethod] | |||
| public unsafe void TensorFromFixed() | |||
| { | |||
| var array = new float[1000]; | |||
| var span = new Span<float>(array, 100, 500); | |||
| fixed (float* ptr = &MemoryMarshal.GetReference(span)) | |||
| { | |||
| using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, 4 * span.Length)) | |||
| { | |||
| Assert.IsFalse(t.IsDisposed); | |||
| Assert.AreEqual(2000, (int)t.bytesize); | |||
| } | |||
| } | |||
| fixed (float* ptr = &array[0]) | |||
| { | |||
| using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length)) | |||
| { | |||
| Assert.IsFalse(t.IsDisposed); | |||
| Assert.AreEqual(4000, (int)t.bytesize); | |||
| } | |||
| } | |||
| } | |||
| [TestMethod] | |||
| public unsafe void TensorFromArray() | |||
| { | |||
| var array = new float[1000]; | |||
| using (var t = new Tensor(array, new long[] { array.Length }, tf.float32)) | |||
| { | |||
| Assert.IsFalse(t.IsDisposed); | |||
| Assert.AreEqual(1000 * sizeof(float), (int)t.bytesize); | |||
| } | |||
| using (var t = new Tensor(new float[] { 1 }, new long[] { 1 }, tf.float32)) | |||
| { | |||
| Assert.IsFalse(t.IsDisposed); | |||
| Assert.AreEqual(1 * sizeof(float), (int)t.bytesize); | |||
| } | |||
| using (var t = new Tensor(new float[] { 1 }, null, tf.float32)) | |||
| { | |||
| Assert.IsFalse(t.IsDisposed); | |||
| Assert.AreEqual(1 * sizeof(float), (int)t.bytesize); | |||
| t.shape.Should().BeEmpty(); | |||
| } | |||
| } | |||
| [TestMethod] | |||
| public void AllocateTensor() | |||
| { | |||
| ulong num_bytes = 6 * sizeof(float); | |||
| long[] dims = { 2, 3 }; | |||
| Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | |||
| EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); | |||
| EXPECT_EQ(2, t.NDims); | |||
| EXPECT_EQ((int)dims[0], t.shape[0]); | |||
| EXPECT_EQ(num_bytes, t.bytesize); | |||
| t.Dispose(); | |||
| } | |||
| /// <summary> | |||
| /// Port from c_api_test.cc | |||
| /// `TEST(CAPI, MaybeMove)` | |||
| /// </summary> | |||
| [TestMethod, Ignore] | |||
| public void MaybeMove() | |||
| { | |||
| NDArray nd = np.array(2, 3); | |||
| Tensor t = new Tensor(nd); | |||
| Tensor o = t.MaybeMove(); | |||
| ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own. | |||
| t.Dispose(); | |||
| } | |||
| /// <summary> | |||
| /// Port from c_api_test.cc | |||
| /// `TEST(CAPI, Tensor)` | |||
| /// </summary> | |||
| [TestMethod] | |||
| public void Tensor() | |||
| { | |||
| var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | |||
| var tensor = new Tensor(nd); | |||
| var array = tensor.ToArray<float>(); | |||
| EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | |||
| EXPECT_EQ(tensor.rank, nd.ndim); | |||
| EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); | |||
| EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); | |||
| EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] { 1, 2, 3, 4, 5, 6 })); | |||
| } | |||
| /// <summary> | |||
| /// Port from tensorflow\c\c_api_test.cc | |||
| /// `TEST(CAPI, SetShape)` | |||
| /// </summary> | |||
| [TestMethod] | |||
| public void SetShape() | |||
| { | |||
| var s = new Status(); | |||
| var graph = new Graph().as_default(); | |||
| var feed = c_test_util.Placeholder(graph, s); | |||
| var feed_out_0 = new TF_Output(feed, 0); | |||
| // Fetch the shape, it should be completely unknown. | |||
| int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| EXPECT_EQ(-1, num_dims); | |||
| // Set the shape to be unknown, expect no change. | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||
| EXPECT_EQ(-1, num_dims); | |||
| // Set the shape to be 2 x Unknown | |||
| long[] dims = { 2, -1 }; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||
| EXPECT_EQ(2, num_dims); | |||
| // Get the dimension vector appropriately. | |||
| var returned_dims = new long[dims.Length]; | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
| // Set to a new valid shape: [2, 3] | |||
| dims[1] = 3; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| // Fetch and see that the new value is returned. | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
| // Try to set 'unknown' with unknown rank on the shape and see that | |||
| // it doesn't change. | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| EXPECT_EQ(2, num_dims); | |||
| EXPECT_EQ(2, (int)returned_dims[0]); | |||
| EXPECT_EQ(3, (int)returned_dims[1]); | |||
| // Try to set 'unknown' with same rank on the shape and see that | |||
| // it doesn't change. | |||
| dims[0] = -1; | |||
| dims[1] = -1; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| EXPECT_EQ(2, num_dims); | |||
| EXPECT_EQ(2, (int)returned_dims[0]); | |||
| EXPECT_EQ(3, (int)returned_dims[1]); | |||
| // Try to fetch a shape with the wrong num_dims | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
| // Try to set an invalid shape (cannot change 2x3 to a 2x5). | |||
| dims[1] = 5; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
| // Test for a scalar. | |||
| var three = c_test_util.ScalarConst(3, graph, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| var three_out_0 = new TF_Output(three, 0); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| EXPECT_EQ(0, num_dims); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s.Handle); | |||
| //Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| // graph.Dispose(); | |||
| s.Dispose(); | |||
| } | |||
| [TestMethod] | |||
| public void sparse_to_dense() | |||
| { | |||
| @@ -271,32 +81,5 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>())); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Creates a tensor from an image of 256x256x3 and resizes it to 100x100x3 | |||
| /// </summary> | |||
| [TestMethod] | |||
| public unsafe void tensor_resize() | |||
| { | |||
| tf.enable_eager_execution(); | |||
| var imageArray = new float[256 * 256 * 3]; | |||
| using var newSize = tf.convert_to_tensor(new int[] { 100, 100 }); | |||
| using (var t = tf.constant(imageArray, tf.float32, (1, 256, 256, 3))) | |||
| { | |||
| Assert.IsFalse(t.IsDisposed); | |||
| Assert.AreEqual(256 * 256 * 3 * sizeof(float), (int)t.bytesize); | |||
| using var resized = tf.image.resize_bilinear(t, newSize); | |||
| EXPECT_EQ(resized.shape[0], 1); | |||
| EXPECT_EQ(resized.shape[1], 100); | |||
| EXPECT_EQ(resized.shape[2], 100); | |||
| EXPECT_EQ(resized.shape[3], 3); | |||
| } | |||
| tf.compat.v1.disable_eager_execution(); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest | |||
| @@ -10,11 +11,40 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| if (!tf.executing_eagerly()) | |||
| tf.enable_eager_execution(); | |||
| tf.Context.ensure_initialized(); | |||
| } | |||
| [TestCleanup] | |||
| public void TestClean() | |||
| { | |||
| } | |||
| public bool Equal(float[] f1, float[] f2) | |||
| { | |||
| bool ret = false; | |||
| var tolerance = .000001f; | |||
| for (var i = 0; i < f1.Length; i++) | |||
| { | |||
| ret = Math.Abs(f1[i] - f2[i]) <= tolerance; | |||
| if (!ret) | |||
| break; | |||
| } | |||
| return ret; | |||
| } | |||
| public bool Equal(double[] d1, double[] d2) | |||
| { | |||
| bool ret = false; | |||
| var tolerance = .000000000000001f; | |||
| for (var i = 0; i < d1.Length; i++) | |||
| { | |||
| ret = Math.Abs(d1[i] - d2[i]) <= tolerance; | |||
| if (!ret) | |||
| break; | |||
| } | |||
| return ret; | |||
| } | |||
| } | |||
| } | |||
| @@ -6,10 +6,10 @@ using System.Threading; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.NativeAPI | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class EnforcedSinglethreadingTests : CApiTest | |||
| public class EnforcedSinglethreadingTests | |||
| { | |||
| private static readonly object _singlethreadLocker = new object(); | |||
| @@ -1,6 +1,7 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using TensorFlowNET.UnitTest; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| namespace Tensorflow.UnitTest | |||
| { | |||
| @@ -15,6 +16,7 @@ namespace Tensorflow.UnitTest | |||
| [TestCleanup] | |||
| public void TestClean() | |||
| { | |||
| keras.backend.clear_session(); | |||
| tf.enable_eager_execution(); | |||
| } | |||
| } | |||
| @@ -5,7 +5,7 @@ using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.nn_test | |||
| { | |||
| [TestClass] | |||
| public class ActivationFunctionTest : TFNetApiTest | |||
| public class ActivationFunctionTest : EagerModeTestBase | |||
| { | |||
| // A constant vector of size 6 | |||
| Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); | |||
| @@ -6,7 +6,7 @@ using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.ManagedAPI | |||
| { | |||
| [TestClass] | |||
| public class BitwiseApiTest : TFNetApiTest | |||
| public class BitwiseApiTest : EagerModeTestBase | |||
| { | |||
| [TestInitialize] | |||
| public void Init() | |||
| @@ -7,7 +7,7 @@ using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.ManagedAPI | |||
| { | |||
| [TestClass] | |||
| public class FunctionApiTest : TFNetApiTest | |||
| public class FunctionApiTest : EagerModeTestBase | |||
| { | |||
| Tensor Min(Tensor a, Tensor b) | |||
| { | |||
| @@ -6,7 +6,7 @@ using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.ManagedAPI | |||
| { | |||
| [TestClass] | |||
| public class MathApiTest : TFNetApiTest | |||
| public class MathApiTest : EagerModeTestBase | |||
| { | |||
| // A constant vector of size 6 | |||
| Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f }); | |||
| @@ -1,35 +0,0 @@ | |||
| using System; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| public class TFNetApiTest | |||
| { | |||
| public bool Equal(float[] f1, float[] f2) | |||
| { | |||
| bool ret = false; | |||
| var tolerance = .000001f; | |||
| for (var i = 0; i < f1.Length; i++) | |||
| { | |||
| ret = Math.Abs(f1[i] - f2[i]) <= tolerance; | |||
| if (!ret) | |||
| break; | |||
| } | |||
| return ret; | |||
| } | |||
| public bool Equal(double[] d1, double[] d2) | |||
| { | |||
| bool ret = false; | |||
| var tolerance = .000000000000001f; | |||
| for (var i = 0; i < d1.Length; i++) | |||
| { | |||
| ret = Math.Abs(d1[i] - d2[i]) <= tolerance; | |||
| if (!ret) | |||
| break; | |||
| } | |||
| return ret; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,86 +0,0 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using NumSharp; | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow; | |||
| using Tensorflow.UnitTest; | |||
| namespace TensorFlowNET.UnitTest.nn_test | |||
| { | |||
| [TestClass] | |||
| public class ZeroFractionTest : GraphModeTestBase | |||
| { | |||
| protected double _ZeroFraction(NDArray x) | |||
| { | |||
| assert(x.shape); | |||
| int total_elements = np.prod(x.shape); | |||
| var eps = 1e-8; | |||
| var nonzeros = x.Data<double>().Count(d => Math.Abs(d) > eps); | |||
| return 1.0 - nonzeros / (double)total_elements; | |||
| } | |||
| [Ignore("TODO implement nn_impl.zero_fraction")] | |||
| [TestMethod] | |||
| public void testZeroFraction() | |||
| { | |||
| var x_shape = new Shape(5, 17); | |||
| var x_np = np.random.randint(0, 2, x_shape); | |||
| //x_np.astype(np.float32); | |||
| var y_np = this._ZeroFraction(x_np); | |||
| var x_tf = constant_op.constant(x_np); | |||
| x_tf.set_shape(x_shape); | |||
| var y_tf = nn_impl.zero_fraction(x_tf); | |||
| var y_tf_np = self.evaluate<NDArray>(y_tf); | |||
| var eps = 1e-8; | |||
| self.assertAllClose(y_tf_np, y_np, eps); | |||
| } | |||
| [Ignore("TODO implement nn_impl.zero_fraction")] | |||
| [TestMethod] | |||
| public void testZeroFractionEmpty() | |||
| { | |||
| var x = np.zeros(0); | |||
| var y = self.evaluate<NDArray>(nn_impl.zero_fraction(new Tensor(x))); | |||
| self.assertTrue(np.isnan(y)); | |||
| } | |||
| [Ignore("TODO implement nn_impl.zero_fraction")] | |||
| [TestMethod] | |||
| public void testZeroFraction2_27Zeros() | |||
| { | |||
| var sparsity = nn_impl.zero_fraction( | |||
| array_ops.zeros(new Shape((int)Math.Pow(2, 27 * 1.01)), dtypes.int8)); | |||
| self.assertAllClose(1.0, self.evaluate<NDArray>(sparsity)); | |||
| } | |||
| [Ignore("TODO implement nn_impl.zero_fraction")] | |||
| [TestMethod] | |||
| public void testZeroFraction2_27Ones() | |||
| { | |||
| var sparsity = nn_impl.zero_fraction( | |||
| array_ops.ones(new TensorShape((int)Math.Pow(2, 27 * 1.01)), dtypes.int8)); | |||
| self.assertAllClose(0.0, self.evaluate<NDArray>(sparsity)); | |||
| } | |||
| [Ignore("TODO implement nn_impl.zero_fraction")] | |||
| [TestMethod] | |||
| public void testUnknownSize() | |||
| { | |||
| var value = array_ops.placeholder(dtype: dtypes.float32); | |||
| var sparsity = nn_impl.zero_fraction(value); | |||
| using (var sess = self.cached_session()) | |||
| { | |||
| // TODO: make this compile | |||
| //self.assertAllClose( | |||
| // 0.25, | |||
| // sess.run(sparsity, {value: [[0., 1.], [0.3, 2.]]})); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -73,6 +73,7 @@ namespace TensorFlowNET.UnitTest | |||
| tf.peak_default_graph().Should().BeNull(); | |||
| var beforehand = tf.get_default_graph(); //this should create default automatically. | |||
| beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread."); | |||
| beforehand.as_default(); | |||
| tf.peak_default_graph().Should().NotBeNull(); | |||
| using (var sess = tf.Session()) | |||
| @@ -1,4 +1,5 @@ | |||
| using System.IO; | |||
| using System; | |||
| using System.IO; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| @@ -6,8 +7,16 @@ namespace TensorFlowNET.UnitTest | |||
| { | |||
| public static string GetFullPathFromDataDir(string fileName) | |||
| { | |||
| var dir = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "..", "..", "data"); | |||
| return Path.GetFullPath(Path.Combine(dir, fileName)); | |||
| var dataDir = GetRootContentDir(Directory.GetCurrentDirectory()); | |||
| return Path.Combine(dataDir, fileName); | |||
| } | |||
| static string GetRootContentDir(string dir) | |||
| { | |||
| var path = Path.GetFullPath(Path.Combine(dir, "data")); | |||
| if (Directory.Exists(path)) | |||
| return path; | |||
| return GetRootContentDir(Path.GetFullPath(Path.Combine(dir, ".."))); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,875 +0,0 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Newtonsoft.Json.Linq; | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using Tensorflow; | |||
| using Tensorflow.UnitTest; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.nest_test | |||
| { | |||
| /// <summary> | |||
| /// excerpt of tensorflow/python/framework/util/nest_test.py | |||
| /// </summary> | |||
| [TestClass] | |||
| public class NestTest : GraphModeTestBase | |||
| { | |||
| [TestInitialize] | |||
| public void TestInitialize() | |||
| { | |||
| tf.Graph().as_default(); | |||
| } | |||
| //public class PointXY | |||
| //{ | |||
| // public double x; | |||
| // public double y; | |||
| //} | |||
| // if attr: | |||
| // class BadAttr(object): | |||
| // """Class that has a non-iterable __attrs_attrs__.""" | |||
| // __attrs_attrs__ = None | |||
| // @attr.s | |||
| // class SampleAttr(object): | |||
| // field1 = attr.ib() | |||
| // field2 = attr.ib() | |||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| // def testAttrsFlattenAndPack(self) : | |||
| // if attr is None: | |||
| // self.skipTest("attr module is unavailable.") | |||
| // field_values = [1, 2] | |||
| // sample_attr = NestTest.SampleAttr(* field_values) | |||
| // self.assertFalse(nest._is_attrs(field_values)) | |||
| // self.assertTrue(nest._is_attrs(sample_attr)) | |||
| // flat = nest.flatten(sample_attr) | |||
| // self.assertEqual(field_values, flat) | |||
| // restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) | |||
| // self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) | |||
| // self.assertEqual(restructured_from_flat, sample_attr) | |||
| //# Check that flatten fails if attributes are not iterable | |||
| // with self.assertRaisesRegexp(TypeError, "object is not iterable"): | |||
| // flat = nest.flatten(NestTest.BadAttr()) | |||
| [Ignore] | |||
| [TestMethod] | |||
| public void testFlattenAndPack() | |||
| { | |||
| object structure = new object[] { new object[] { 3, 4 }, 5, new object[] { 6, 7, new object[] { 9, 10 }, 8 } }; | |||
| var flat = new List<object> { "a", "b", "c", "d", "e", "f", "g", "h" }; | |||
| self.assertEqual(nest.flatten(structure), new[] { 3, 4, 5, 6, 7, 9, 10, 8 }); | |||
| self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(), | |||
| JArray.FromObject(new object[] { new object[] { "a", "b" }, "c", new object[] { "d", "e", new object[] { "f", "g" }, "h" } }).ToString()); | |||
| structure = new object[] { new Hashtable { ["x"] = 4, ["y"] = 2 }, new object[] { new object[] { new Hashtable { ["x"] = 1, ["y"] = 0 }, }, } }; | |||
| flat = new List<object> { 4, 2, 1, 0 }; | |||
| self.assertEqual(nest.flatten(structure), flat); | |||
| var restructured_from_flat = nest.pack_sequence_as(structure, flat) as object[]; | |||
| //Console.WriteLine(JArray.FromObject(restructured_from_flat)); | |||
| self.assertEqual(restructured_from_flat, structure); | |||
| self.assertEqual((restructured_from_flat[0] as Hashtable)["x"], 4); | |||
| self.assertEqual((restructured_from_flat[0] as Hashtable)["y"], 2); | |||
| self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["x"], 1); | |||
| self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["y"], 0); | |||
| self.assertEqual(new List<object> { 5 }, nest.flatten(5)); | |||
| var flat1 = nest.flatten(np.array(new[] { 5 })); | |||
| self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat1); | |||
| self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" })); | |||
| self.assertEqual(np.array(new[] { 5 }), | |||
| nest.pack_sequence_as("scalar", new List<object> { np.array(new[] { 5 }) })); | |||
| Assert.ThrowsException<ValueError>(() => nest.pack_sequence_as("scalar", new List<object>() { 4, 5 })); | |||
| Assert.ThrowsException<ValueError>(() => | |||
| nest.pack_sequence_as(new object[] { 5, 6, new object[] { 7, 8 } }, new List<object> { "a", "b", "c" })); | |||
| } | |||
| // @parameterized.parameters({"mapping_type": collections.OrderedDict | |||
| // }, | |||
| // {"mapping_type": _CustomMapping | |||
| //}) | |||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| // def testFlattenDictOrder(self, mapping_type) : | |||
| // """`flatten` orders dicts by key, including OrderedDicts.""" | |||
| // ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) | |||
| // plain = {"d": 3, "b": 1, "a": 0, "c": 2} | |||
| // ordered_flat = nest.flatten(ordered) | |||
| // plain_flat = nest.flatten(plain) | |||
| // self.assertEqual([0, 1, 2, 3], ordered_flat) | |||
| // self.assertEqual([0, 1, 2, 3], plain_flat) | |||
| // @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||
| // {"mapping_type": _CustomMapping}) | |||
| // def testPackDictOrder(self, mapping_type): | |||
| // """Packing orders dicts by key, including OrderedDicts.""" | |||
| // custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) | |||
| // plain = {"d": 0, "b": 0, "a": 0, "c": 0} | |||
| // seq = [0, 1, 2, 3] | |||
| //custom_reconstruction = nest.pack_sequence_as(custom, seq) | |||
| //plain_reconstruction = nest.pack_sequence_as(plain, seq) | |||
| // self.assertIsInstance(custom_reconstruction, mapping_type) | |||
| // self.assertIsInstance(plain_reconstruction, dict) | |||
| // self.assertEqual( | |||
| // mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), | |||
| // custom_reconstruction) | |||
| // self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) | |||
| // Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name | |||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| // def testFlattenAndPack_withDicts(self) : | |||
| // # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. | |||
| // mess = [ | |||
| // "z", | |||
| // NestTest.Abc(3, 4), { | |||
| // "d": _CustomMapping({ | |||
| // 41: 4 | |||
| // }), | |||
| // "c": [ | |||
| // 1, | |||
| // collections.OrderedDict([ | |||
| // ("b", 3), | |||
| // ("a", 2), | |||
| // ]), | |||
| // ], | |||
| // "b": 5 | |||
| // }, 17 | |||
| // ] | |||
| // flattened = nest.flatten(mess) | |||
| // self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) | |||
| // structure_of_mess = [ | |||
| // 14, | |||
| // NestTest.Abc("a", True), | |||
| // { | |||
| // "d": _CustomMapping({ | |||
| // 41: 42 | |||
| // }), | |||
| // "c": [ | |||
| // 0, | |||
| // collections.OrderedDict([ | |||
| // ("b", 9), | |||
| // ("a", 8), | |||
| // ]), | |||
| // ], | |||
| // "b": 3 | |||
| // }, | |||
| // "hi everybody", | |||
| // ] | |||
| // unflattened = nest.pack_sequence_as(structure_of_mess, flattened) | |||
| // self.assertEqual(unflattened, mess) | |||
| // # Check also that the OrderedDict was created, with the correct key order. | |||
| //unflattened_ordered_dict = unflattened[2]["c"][1] | |||
| // self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) | |||
| // self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) | |||
| // unflattened_custom_mapping = unflattened[2]["d"] | |||
| // self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) | |||
| // self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) | |||
| [TestMethod] | |||
| public void testFlatten_numpyIsNotFlattened() | |||
| { | |||
| var structure = np.array(1, 2, 3); | |||
| var flattened = nest.flatten(structure); | |||
| self.assertEqual(len(flattened), 1); | |||
| } | |||
| [TestMethod] | |||
| public void testFlatten_stringIsNotFlattened() | |||
| { | |||
| var structure = "lots of letters"; | |||
| var flattened = nest.flatten(structure); | |||
| self.assertEqual(len(flattened), 1); | |||
| var unflattened = nest.pack_sequence_as("goodbye", flattened); | |||
| self.assertEqual(structure, unflattened); | |||
| } | |||
| // def testPackSequenceAs_notIterableError(self) : | |||
| // with self.assertRaisesRegexp(TypeError, | |||
| // "flat_sequence must be a sequence"): | |||
| // nest.pack_sequence_as("hi", "bye") | |||
| [TestMethod] | |||
| public void testPackSequenceAs_wrongLengthsError() | |||
| { | |||
| Assert.ThrowsException<ValueError>(() => | |||
| { | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // "Structure had 2 elements, but flat_sequence had 3 elements."): | |||
| nest.pack_sequence_as(new object[] { "hello", "world" }, new object[] { "and", "goodbye", "again" }); | |||
| }); | |||
| } | |||
| [Ignore] | |||
| [TestMethod] | |||
| public void testIsSequence() | |||
| { | |||
| self.assertFalse(nest.is_sequence("1234")); | |||
| self.assertTrue(nest.is_sequence(new object[] { 1, 3, new object[] { 4, 5 } })); | |||
| // TODO: ValueTuple<T,T> | |||
| //self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))); | |||
| self.assertTrue(nest.is_sequence(new object[] { })); | |||
| self.assertTrue(nest.is_sequence(new Hashtable { ["a"] = 1, ["b"] = 2 })); | |||
| self.assertFalse(nest.is_sequence(new HashSet<int> { 1, 2 })); | |||
| var ones = array_ops.ones(new int[] { 2, 3 }); | |||
| self.assertFalse(nest.is_sequence(ones)); | |||
| self.assertFalse(nest.is_sequence(gen_math_ops.tanh(ones))); | |||
| self.assertFalse(nest.is_sequence(np.ones(new int[] { 4, 5 }))); | |||
| } | |||
| // @parameterized.parameters({"mapping_type": _CustomMapping}, | |||
| // {"mapping_type": dict}) | |||
| // def testFlattenDictItems(self, mapping_type): | |||
| // dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))}) | |||
| // flat = {4: "a", 5: "b", 6: "c", 8: "d"} | |||
| // self.assertEqual(nest.flatten_dict_items(dictionary), flat) | |||
| // with self.assertRaises(TypeError): | |||
| // nest.flatten_dict_items(4) | |||
| // bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))}) | |||
| // with self.assertRaisesRegexp(ValueError, "not unique"): | |||
| // nest.flatten_dict_items(bad_dictionary) | |||
| // another_bad_dictionary = mapping_type({ | |||
| // (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) | |||
| // }) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): | |||
| // nest.flatten_dict_items(another_bad_dictionary) | |||
| //# pylint does not correctly recognize these as class names and | |||
| //# suggests to use variable style under_score naming. | |||
| //# pylint: disable=invalid-name | |||
| // Named0ab = collections.namedtuple("named_0", ("a", "b")) | |||
| // Named1ab = collections.namedtuple("named_1", ("a", "b")) | |||
| // SameNameab = collections.namedtuple("same_name", ("a", "b")) | |||
| // SameNameab2 = collections.namedtuple("same_name", ("a", "b")) | |||
| // SameNamexy = collections.namedtuple("same_name", ("x", "y")) | |||
| // SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) | |||
| // SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) | |||
| // NotSameName = collections.namedtuple("not_same_name", ("a", "b")) | |||
| // # pylint: enable=invalid-name | |||
| // class SameNamedType1(SameNameab): | |||
| // pass | |||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| // def testAssertSameStructure(self): | |||
| // structure1 = (((1, 2), 3), 4, (5, 6)) | |||
| // structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||
| // structure_different_num_elements = ("spam", "eggs") | |||
| // structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) | |||
| // nest.assert_same_structure(structure1, structure2) | |||
| // nest.assert_same_structure("abc", 1.0) | |||
| // nest.assert_same_structure("abc", np.array([0, 1])) | |||
| // nest.assert_same_structure("abc", constant_op.constant([0, 1])) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // ("The two structures don't have the same nested structure\\.\n\n" | |||
| // "First structure:.*?\n\n" | |||
| // "Second structure:.*\n\n" | |||
| // "More specifically: Substructure " | |||
| // r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' | |||
| // 'substructure "type=str str=spam" is not\n' | |||
| // "Entire first structure:\n" | |||
| // r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" | |||
| // "Entire second structure:\n" | |||
| // r"\(\., \.\)")): | |||
| // nest.assert_same_structure(structure1, structure_different_num_elements) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // ("The two structures don't have the same nested structure\\.\n\n" | |||
| // "First structure:.*?\n\n" | |||
| // "Second structure:.*\n\n" | |||
| // r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||
| // r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' | |||
| // "is not")): | |||
| // nest.assert_same_structure([0, 1], np.array([0, 1])) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // ("The two structures don't have the same nested structure\\.\n\n" | |||
| // "First structure:.*?\n\n" | |||
| // "Second structure:.*\n\n" | |||
| // r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||
| // 'is a sequence, while substructure "type=int str=0" ' | |||
| // "is not")): | |||
| // nest.assert_same_structure(0, [0, 1]) | |||
| // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // ("don't have the same nested structure\\.\n\n" | |||
| // "First structure: .*?\n\nSecond structure: ")): | |||
| // nest.assert_same_structure(structure1, structure_different_nesting) | |||
| // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), | |||
| // NestTest.Named0ab("a", "b")) | |||
| // nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||
| // NestTest.Named0ab("a", "b")) | |||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||
| // NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // ("don't have the same nested structure\\.\n\n" | |||
| // "First structure: .*?\n\nSecond structure: ")): | |||
| // nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||
| // NestTest.Named0ab([3], 4)) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // ("don't have the same nested structure\\.\n\n" | |||
| // "First structure: .*?\n\nSecond structure: ")): | |||
| // nest.assert_same_structure([[3], 4], [3, [4]]) | |||
| // structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||
| // with self.assertRaisesRegexp(TypeError, | |||
| // "don't have the same sequence type"): | |||
| // nest.assert_same_structure(structure1, structure1_list) | |||
| // nest.assert_same_structure(structure1, structure2, check_types= False) | |||
| // nest.assert_same_structure(structure1, structure1_list, check_types=False) | |||
| // with self.assertRaisesRegexp(ValueError, | |||
| // "don't have the same set of keys"): | |||
| // nest.assert_same_structure({"a": 1}, {"b": 1}) | |||
| // nest.assert_same_structure(NestTest.SameNameab(0, 1), | |||
| // NestTest.SameNameab2(2, 3)) | |||
| // # This assertion is expected to pass: two namedtuples with the same | |||
| // # name and field names are considered to be identical. | |||
| // nest.assert_same_structure( | |||
| // NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), | |||
| // NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) | |||
| // expected_message = "The two structures don't have the same.*" | |||
| // with self.assertRaisesRegexp(ValueError, expected_message): | |||
| // nest.assert_same_structure( | |||
| // NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), | |||
| // NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) | |||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||
| // NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) | |||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||
| // NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) | |||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||
| // NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) | |||
| // EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name | |||
| // def testHeterogeneousComparison(self): | |||
| // nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3)) | |||
| // nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) | |||
| [Ignore] | |||
| [TestMethod] | |||
| public void testMapStructure() | |||
| { | |||
| var structure1 = new object[] { new object[] { new object[] { 1, 2 }, 3 }, 4, new object[] { 5, 6 } }; | |||
| var structure2 = new object[] { new object[] { new object[] { 7, 8 }, 9 }, 10, new object[] { 11, 12 } }; | |||
| var structure1_plus1 = nest.map_structure(x => (int)x + 1, structure1); | |||
| var structure1_strings = nest.map_structure(x => $"{x}", structure1); | |||
| var s = JArray.FromObject(structure1_plus1).ToString(); | |||
| Console.WriteLine(s); | |||
| // nest.assert_same_structure(structure1, structure1_plus1) | |||
| self.assertAllEqual(nest.flatten(structure1_plus1), new object[] { 2, 3, 4, 5, 6, 7 }); | |||
| self.assertAllEqual(nest.flatten(structure1_strings), new object[] { "1", "2", "3", "4", "5", "6" }); | |||
| var structure1_plus_structure2 = nest.map_structure(x => (int)(x[0]) + (int)(x[1]), structure1, structure2); | |||
| self.assertEqual( | |||
| new object[] { new object[] { new object[] { 1 + 7, 2 + 8 }, 3 + 9 }, 4 + 10, new object[] { 5 + 11, 6 + 12 } }, | |||
| structure1_plus_structure2); | |||
| // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) | |||
| // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) | |||
| // # Empty structures | |||
| // self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) | |||
| // self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) | |||
| // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) | |||
| // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, | |||
| // NestTest.EmptyNT())) | |||
| // # This is checking actual equality of types, empty list != empty tuple | |||
| // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) | |||
| // with self.assertRaisesRegexp(TypeError, "callable"): | |||
| // nest.map_structure("bad", structure1_plus1) | |||
| // with self.assertRaisesRegexp(ValueError, "at least one structure"): | |||
| // nest.map_structure(lambda x: x) | |||
| // with self.assertRaisesRegexp(ValueError, "same number of elements"): | |||
| // nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) | |||
| // with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
| // nest.map_structure(lambda x, y: None, 3, (3,)) | |||
| // with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||
| // nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) | |||
| // with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
| // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) | |||
| // structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||
| // with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||
| // nest.map_structure(lambda x, y: None, structure1, structure1_list) | |||
| // nest.map_structure(lambda x, y: None, structure1, structure1_list, | |||
| // check_types=False) | |||
| // with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
| // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), | |||
| // check_types=False) | |||
| // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||
| // nest.map_structure(lambda x: None, structure1, foo="a") | |||
| // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||
| // nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") | |||
| // ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name | |||
| } | |||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| // def testMapStructureWithStrings(self) : | |||
| // inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) | |||
| // inp_b = NestTest.ABTuple(a=2, b=(1, 3)) | |||
| // out = nest.map_structure(lambda string, repeats: string* repeats, | |||
| // inp_a, | |||
| // inp_b) | |||
| // self.assertEqual("foofoo", out.a) | |||
| // self.assertEqual("bar", out.b[0]) | |||
| // self.assertEqual("bazbazbaz", out.b[1]) | |||
| // nt = NestTest.ABTuple(a=("something", "something_else"), | |||
| // b="yet another thing") | |||
| // rev_nt = nest.map_structure(lambda x: x[::- 1], nt) | |||
| // # Check the output is the correct structure, and all strings are reversed. | |||
| // nest.assert_same_structure(nt, rev_nt) | |||
| // self.assertEqual(nt.a[0][::- 1], rev_nt.a[0]) | |||
| // self.assertEqual(nt.a[1][::- 1], rev_nt.a[1]) | |||
| // self.assertEqual(nt.b[::- 1], rev_nt.b) | |||
| // @test_util.run_deprecated_v1 | |||
| // def testMapStructureOverPlaceholders(self) : | |||
| // inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||
| // array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||
| // inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||
| // array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||
| // output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) | |||
| // nest.assert_same_structure(output, inp_a) | |||
| // self.assertShapeEqual(np.zeros((3, 4)), output[0]) | |||
| // self.assertShapeEqual(np.zeros((3, 7)), output[1]) | |||
| // feed_dict = { | |||
| // inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), | |||
| // inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) | |||
| // } | |||
| // with self.cached_session() as sess: | |||
| // output_np = sess.run(output, feed_dict=feed_dict) | |||
| // self.assertAllClose(output_np[0], | |||
| // feed_dict[inp_a][0] + feed_dict[inp_b][0]) | |||
| // self.assertAllClose(output_np[1], | |||
| // feed_dict[inp_a][1] + feed_dict[inp_b][1]) | |||
| // def testAssertShallowStructure(self): | |||
| // inp_ab = ["a", "b"] | |||
| //inp_abc = ["a", "b", "c"] | |||
| //expected_message = ( | |||
| // "The two structures don't have the same sequence length. Input " | |||
| // "structure has length 2, while shallow structure has length 3.") | |||
| // with self.assertRaisesRegexp(ValueError, expected_message): | |||
| // nest.assert_shallow_structure(inp_abc, inp_ab) | |||
| // inp_ab1 = [(1, 1), (2, 2)] | |||
| // inp_ab2 = [[1, 1], [2, 2]] | |||
| // expected_message = ( | |||
| // "The two structures don't have the same sequence type. Input structure " | |||
| // "has type <(type|class) 'tuple'>, while shallow structure has type " | |||
| // "<(type|class) 'list'>.") | |||
| // with self.assertRaisesRegexp(TypeError, expected_message): | |||
| // nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||
| // nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False) | |||
| // inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} | |||
| // inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} | |||
| // expected_message = ( | |||
| // r"The two structures don't have the same keys. Input " | |||
| // r"structure has keys \['c'\], while shallow structure has " | |||
| // r"keys \['d'\].") | |||
| // with self.assertRaisesRegexp(ValueError, expected_message): | |||
| // nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||
| // inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) | |||
| // inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) | |||
| // nest.assert_shallow_structure(inp_ab, inp_ba) | |||
| // # This assertion is expected to pass: two namedtuples with the same | |||
| //# name and field names are considered to be identical. | |||
| //inp_shallow = NestTest.SameNameab(1, 2) | |||
| // inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) | |||
| // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) | |||
| // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) | |||
| // def testFlattenUpTo(self): | |||
| // # Shallow tree ends at scalar. | |||
| // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||
| // shallow_tree = [[True, True], [False, True]] | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) | |||
| // self.assertEqual(flattened_shallow_tree, [True, True, False, True]) | |||
| //# Shallow tree ends at string. | |||
| // input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] | |||
| // shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] | |||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| // input_tree) | |||
| // input_tree_flattened = nest.flatten(input_tree) | |||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| // [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) | |||
| // self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) | |||
| // # Make sure dicts are correctly flattened, yielding values, not keys. | |||
| //input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} | |||
| // shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} | |||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| // input_tree) | |||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| // [1, { "c": 2}, 3, (4, 5)]) | |||
| // # Namedtuples. | |||
| // ab_tuple = NestTest.ABTuple | |||
| // input_tree = ab_tuple(a =[0, 1], b = 2) | |||
| // shallow_tree = ab_tuple(a= 0, b= 1) | |||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| // input_tree) | |||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| // [[0, 1], 2]) | |||
| // # Nested dicts, OrderedDicts and namedtuples. | |||
| // input_tree = collections.OrderedDict( | |||
| // [("a", ab_tuple(a =[0, {"b": 1}], b=2)), | |||
| // ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) | |||
| // shallow_tree = input_tree | |||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| // input_tree) | |||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) | |||
| // shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) | |||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| // input_tree) | |||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| // [ab_tuple(a =[0, { "b": 1}], b=2), | |||
| // 3, | |||
| // collections.OrderedDict([("f", 4)])]) | |||
| // shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) | |||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| // input_tree) | |||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| // [ab_tuple(a =[0, {"b": 1}], b=2), | |||
| // {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) | |||
| // ## Shallow non-list edge-case. | |||
| // # Using iterable elements. | |||
| // input_tree = ["input_tree"] | |||
| //shallow_tree = "shallow_tree" | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| // input_tree = ["input_tree_0", "input_tree_1"] | |||
| //shallow_tree = "shallow_tree" | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| // # Using non-iterable elements. | |||
| //input_tree = [0] | |||
| //shallow_tree = 9 | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| // input_tree = [0, 1] | |||
| //shallow_tree = 9 | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| // ## Both non-list edge-case. | |||
| //# Using iterable elements. | |||
| //input_tree = "input_tree" | |||
| // shallow_tree = "shallow_tree" | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| // # Using non-iterable elements. | |||
| //input_tree = 0 | |||
| // shallow_tree = 0 | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| // ## Input non-list edge-case. | |||
| //# Using iterable elements. | |||
| //input_tree = "input_tree" | |||
| // shallow_tree = ["shallow_tree"] | |||
| //expected_message = ("If shallow structure is a sequence, input must also " | |||
| // "be a sequence. Input has type: <(type|class) 'str'>.") | |||
| // with self.assertRaisesRegexp(TypeError, expected_message): | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| // input_tree = "input_tree" | |||
| // shallow_tree = ["shallow_tree_9", "shallow_tree_8"] | |||
| //with self.assertRaisesRegexp(TypeError, expected_message): | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| //# Using non-iterable elements. | |||
| // input_tree = 0 | |||
| // shallow_tree = [9] | |||
| //expected_message = ("If shallow structure is a sequence, input must also " | |||
| // "be a sequence. Input has type: <(type|class) 'int'>.") | |||
| // with self.assertRaisesRegexp(TypeError, expected_message): | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| // input_tree = 0 | |||
| // shallow_tree = [9, 8] | |||
| //with self.assertRaisesRegexp(TypeError, expected_message): | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| // def testMapStructureUpTo(self) : | |||
| // # Named tuples. | |||
| // ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||
| // op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||
| // inp_val = ab_tuple(a= 2, b= 3) | |||
| // inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3)) | |||
| // out = nest.map_structure_up_to( | |||
| // inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) | |||
| // self.assertEqual(out.a, 6) | |||
| // self.assertEqual(out.b, 15) | |||
| // # Lists. | |||
| // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||
| // name_list = ["evens", ["odds", "primes"]] | |||
| // out = nest.map_structure_up_to( | |||
| // name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), | |||
| // name_list, data_list) | |||
| // self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) | |||
| // # Dicts. | |||
| // inp_val = dict(a= 2, b= 3) | |||
| // inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) | |||
| // out = nest.map_structure_up_to( | |||
| // inp_val, | |||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| // self.assertEqual(out["a"], 6) | |||
| // self.assertEqual(out["b"], 15) | |||
| // # Non-equal dicts. | |||
| // inp_val = dict(a= 2, b= 3) | |||
| // inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) | |||
| // with self.assertRaisesRegexp(ValueError, "same keys"): | |||
| // nest.map_structure_up_to( | |||
| // inp_val, | |||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| // # Dict+custom mapping. | |||
| // inp_val = dict(a= 2, b= 3) | |||
| // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) | |||
| // out = nest.map_structure_up_to( | |||
| // inp_val, | |||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| // self.assertEqual(out["a"], 6) | |||
| // self.assertEqual(out["b"], 15) | |||
| // # Non-equal dict/mapping. | |||
| // inp_val = dict(a= 2, b= 3) | |||
| // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) | |||
| // with self.assertRaisesRegexp(ValueError, "same keys"): | |||
| // nest.map_structure_up_to( | |||
| // inp_val, | |||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| // def testGetTraverseShallowStructure(self): | |||
| // scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] | |||
| // scalar_traverse_r = nest.get_traverse_shallow_structure( | |||
| // lambda s: not isinstance(s, tuple), | |||
| // scalar_traverse_input) | |||
| // self.assertEqual(scalar_traverse_r, | |||
| // [True, True, False, [True, True], {"a": False}, []]) | |||
| // nest.assert_shallow_structure(scalar_traverse_r, | |||
| // scalar_traverse_input) | |||
| // structure_traverse_input = [(1, [2]), ([1], 2)] | |||
| // structure_traverse_r = nest.get_traverse_shallow_structure( | |||
| // lambda s: (True, False) if isinstance(s, tuple) else True, | |||
| // structure_traverse_input) | |||
| // self.assertEqual(structure_traverse_r, | |||
| // [(True, False), ([True], False)]) | |||
| // nest.assert_shallow_structure(structure_traverse_r, | |||
| // structure_traverse_input) | |||
| // with self.assertRaisesRegexp(TypeError, "returned structure"): | |||
| // nest.get_traverse_shallow_structure(lambda _: [True], 0) | |||
| // with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): | |||
| // nest.get_traverse_shallow_structure(lambda _: 1, [1]) | |||
| // with self.assertRaisesRegexp( | |||
| // TypeError, "didn't return a depth=1 structure of bools"): | |||
| // nest.get_traverse_shallow_structure(lambda _: [1], [1]) | |||
| // def testYieldFlatStringPaths(self): | |||
| // for inputs_expected in ({"inputs": [], "expected": []}, | |||
| // {"inputs": 3, "expected": [()]}, | |||
| // {"inputs": [3], "expected": [(0,)]}, | |||
| // {"inputs": {"a": 3}, "expected": [("a",)]}, | |||
| // {"inputs": {"a": {"b": 4}}, | |||
| // "expected": [("a", "b")]}, | |||
| // {"inputs": [{"a": 2}], "expected": [(0, "a")]}, | |||
| // {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, | |||
| // {"inputs": [{"a": [(23, 42)]}], | |||
| // "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, | |||
| // {"inputs": [{"a": ([23], 42)}], | |||
| // "expected": [(0, "a", 0, 0), (0, "a", 1)]}, | |||
| // {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, | |||
| // "expected": [("a", "a"), ("c", 0, 0, 0)]}, | |||
| // {"inputs": {"0": [{"1": 23}]}, | |||
| // "expected": [("0", 0, "1")]}): | |||
| // inputs = inputs_expected["inputs"] | |||
| // expected = inputs_expected["expected"] | |||
| // self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) | |||
| // def testFlattenWithStringPaths(self): | |||
| // for inputs_expected in ( | |||
| // {"inputs": [], "expected": []}, | |||
| // {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, | |||
| // {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): | |||
| // inputs = inputs_expected["inputs"] | |||
| // expected = inputs_expected["expected"] | |||
| // self.assertEqual( | |||
| // nest.flatten_with_joined_string_paths(inputs, separator="/"), | |||
| // expected) | |||
| // # Need a separate test for namedtuple as we can't declare tuple definitions | |||
| // # in the @parameterized arguments. | |||
| // def testFlattenNamedTuple(self): | |||
| // # pylint: disable=invalid-name | |||
| // Foo = collections.namedtuple("Foo", ["a", "b"]) | |||
| // Bar = collections.namedtuple("Bar", ["c", "d"]) | |||
| // # pylint: enable=invalid-name | |||
| // test_cases = [ | |||
| // (Foo(a = 3, b = Bar(c = 23, d = 42)), | |||
| // [("a", 3), ("b/c", 23), ("b/d", 42)]), | |||
| // (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")), | |||
| // [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), | |||
| // (Bar(c = 42, d = 43), | |||
| // [("c", 42), ("d", 43)]), | |||
| // (Bar(c =[42], d = 43), | |||
| // [("c/0", 42), ("d", 43)]), | |||
| // ] | |||
| // for inputs, expected in test_cases: | |||
| // self.assertEqual( | |||
| // list(nest.flatten_with_joined_string_paths(inputs)), expected) | |||
| // @parameterized.named_parameters( | |||
| // ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), | |||
| // ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, | |||
| // {"a": ("a", 4), "b": ("b", 6)}), | |||
| // ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), | |||
| // ("nested", | |||
| // {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, | |||
| // {"a": [("a/0", 10), ("a/1", 12)], | |||
| // "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) | |||
| // def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): | |||
| // def format_sum(path, * values): | |||
| // return (path, sum(values)) | |||
| // result = nest.map_structure_with_paths(format_sum, s1, s2, | |||
| // check_types=check_types) | |||
| // self.assertEqual(expected, result) | |||
| // @parameterized.named_parameters( | |||
| // ("tuples", (1, 2), (3, 4, 5), ValueError), | |||
| // ("dicts", {"a": 1}, {"b": 2}, ValueError), | |||
| // ("mixed", (1, 2), [3, 4], TypeError), | |||
| // ("nested", | |||
| // {"a": [2, 3], "b": [1, 3]}, | |||
| // {"b": [5, 6, 7], "a": [8, 9]}, | |||
| // ValueError | |||
| // )) | |||
| // def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): | |||
| // with self.assertRaises(error_type): | |||
| // nest.map_structure_with_paths(lambda path, * s: 0, s1, s2) | |||
| //class NestBenchmark(test.Benchmark): | |||
| // def run_and_report(self, s1, s2, name): | |||
| // burn_iter, test_iter = 100, 30000 | |||
| // for _ in xrange(burn_iter) : | |||
| // nest.assert_same_structure(s1, s2) | |||
| // t0 = time.time() | |||
| // for _ in xrange(test_iter) : | |||
| // nest.assert_same_structure(s1, s2) | |||
| // t1 = time.time() | |||
| // self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, | |||
| // name=name) | |||
| // def benchmark_assert_structure(self): | |||
| // s1 = (((1, 2), 3), 4, (5, 6)) | |||
| // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||
| // self.run_and_report(s1, s2, "assert_same_structure_6_elem") | |||
| // s1 = (((1, 2), 3), 4, (5, 6)) * 10 | |||
| // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 | |||
| // self.run_and_report(s1, s2, "assert_same_structure_60_elem") | |||
| //if __name__ == "__main__": | |||
| // test.main() | |||
| } | |||
| } | |||
| @@ -1,883 +0,0 @@ | |||
| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Tests for utilities working with arbitrarily nested structures.""" | |||
| from __future__ import absolute_import | |||
| from __future__ import division | |||
| from __future__ import print_function | |||
| import collections | |||
| import time | |||
| from absl.testing import parameterized | |||
| import numpy as np | |||
| from six.moves import xrange # pylint: disable=redefined-builtin | |||
| from tensorflow.python.framework import constant_op | |||
| from tensorflow.python.framework import dtypes | |||
| from tensorflow.python.framework import test_util | |||
| from tensorflow.python.ops import array_ops | |||
| from tensorflow.python.ops import math_ops | |||
| from tensorflow.python.platform import test | |||
| from tensorflow.python.util import nest | |||
| try: | |||
| import attr # pylint:disable=g-import-not-at-top | |||
| except ImportError: | |||
| attr = None | |||
| class _CustomMapping(collections.Mapping): | |||
| def __init__(self, *args, **kwargs): | |||
| self._wrapped = dict(*args, **kwargs) | |||
| def __getitem__(self, key): | |||
| return self._wrapped[key] | |||
| def __iter__(self): | |||
| return iter(self._wrapped) | |||
| def __len__(self): | |||
| return len(self._wrapped) | |||
| class NestTest(parameterized.TestCase, test.TestCase): | |||
| PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name | |||
| if attr: | |||
| class BadAttr(object): | |||
| """Class that has a non-iterable __attrs_attrs__.""" | |||
| __attrs_attrs__ = None | |||
| @attr.s | |||
| class SampleAttr(object): | |||
| field1 = attr.ib() | |||
| field2 = attr.ib() | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testAttrsFlattenAndPack(self): | |||
| if attr is None: | |||
| self.skipTest("attr module is unavailable.") | |||
| field_values = [1, 2] | |||
| sample_attr = NestTest.SampleAttr(*field_values) | |||
| self.assertFalse(nest._is_attrs(field_values)) | |||
| self.assertTrue(nest._is_attrs(sample_attr)) | |||
| flat = nest.flatten(sample_attr) | |||
| self.assertEqual(field_values, flat) | |||
| restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) | |||
| self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) | |||
| self.assertEqual(restructured_from_flat, sample_attr) | |||
| # Check that flatten fails if attributes are not iterable | |||
| with self.assertRaisesRegexp(TypeError, "object is not iterable"): | |||
| flat = nest.flatten(NestTest.BadAttr()) | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testFlattenAndPack(self): | |||
| structure = ((3, 4), 5, (6, 7, (9, 10), 8)) | |||
| flat = ["a", "b", "c", "d", "e", "f", "g", "h"] | |||
| self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) | |||
| self.assertEqual( | |||
| nest.pack_sequence_as(structure, flat), (("a", "b"), "c", | |||
| ("d", "e", ("f", "g"), "h"))) | |||
| structure = (NestTest.PointXY(x=4, y=2), | |||
| ((NestTest.PointXY(x=1, y=0),),)) | |||
| flat = [4, 2, 1, 0] | |||
| self.assertEqual(nest.flatten(structure), flat) | |||
| restructured_from_flat = nest.pack_sequence_as(structure, flat) | |||
| self.assertEqual(restructured_from_flat, structure) | |||
| self.assertEqual(restructured_from_flat[0].x, 4) | |||
| self.assertEqual(restructured_from_flat[0].y, 2) | |||
| self.assertEqual(restructured_from_flat[1][0][0].x, 1) | |||
| self.assertEqual(restructured_from_flat[1][0][0].y, 0) | |||
| self.assertEqual([5], nest.flatten(5)) | |||
| self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) | |||
| self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) | |||
| self.assertEqual( | |||
| np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) | |||
| with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): | |||
| nest.pack_sequence_as("scalar", [4, 5]) | |||
| with self.assertRaisesRegexp(TypeError, "flat_sequence"): | |||
| nest.pack_sequence_as([4, 5], "bad_sequence") | |||
| with self.assertRaises(ValueError): | |||
| nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) | |||
| @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||
| {"mapping_type": _CustomMapping}) | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testFlattenDictOrder(self, mapping_type): | |||
| """`flatten` orders dicts by key, including OrderedDicts.""" | |||
| ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) | |||
| plain = {"d": 3, "b": 1, "a": 0, "c": 2} | |||
| ordered_flat = nest.flatten(ordered) | |||
| plain_flat = nest.flatten(plain) | |||
| self.assertEqual([0, 1, 2, 3], ordered_flat) | |||
| self.assertEqual([0, 1, 2, 3], plain_flat) | |||
| @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||
| {"mapping_type": _CustomMapping}) | |||
| def testPackDictOrder(self, mapping_type): | |||
| """Packing orders dicts by key, including OrderedDicts.""" | |||
| custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) | |||
| plain = {"d": 0, "b": 0, "a": 0, "c": 0} | |||
| seq = [0, 1, 2, 3] | |||
| custom_reconstruction = nest.pack_sequence_as(custom, seq) | |||
| plain_reconstruction = nest.pack_sequence_as(plain, seq) | |||
| self.assertIsInstance(custom_reconstruction, mapping_type) | |||
| self.assertIsInstance(plain_reconstruction, dict) | |||
| self.assertEqual( | |||
| mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), | |||
| custom_reconstruction) | |||
| self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) | |||
| Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testFlattenAndPack_withDicts(self): | |||
| # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. | |||
| mess = [ | |||
| "z", | |||
| NestTest.Abc(3, 4), { | |||
| "d": _CustomMapping({ | |||
| 41: 4 | |||
| }), | |||
| "c": [ | |||
| 1, | |||
| collections.OrderedDict([ | |||
| ("b", 3), | |||
| ("a", 2), | |||
| ]), | |||
| ], | |||
| "b": 5 | |||
| }, 17 | |||
| ] | |||
| flattened = nest.flatten(mess) | |||
| self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) | |||
| structure_of_mess = [ | |||
| 14, | |||
| NestTest.Abc("a", True), | |||
| { | |||
| "d": _CustomMapping({ | |||
| 41: 42 | |||
| }), | |||
| "c": [ | |||
| 0, | |||
| collections.OrderedDict([ | |||
| ("b", 9), | |||
| ("a", 8), | |||
| ]), | |||
| ], | |||
| "b": 3 | |||
| }, | |||
| "hi everybody", | |||
| ] | |||
| unflattened = nest.pack_sequence_as(structure_of_mess, flattened) | |||
| self.assertEqual(unflattened, mess) | |||
| # Check also that the OrderedDict was created, with the correct key order. | |||
| unflattened_ordered_dict = unflattened[2]["c"][1] | |||
| self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) | |||
| self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) | |||
| unflattened_custom_mapping = unflattened[2]["d"] | |||
| self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) | |||
| self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) | |||
| def testFlatten_numpyIsNotFlattened(self): | |||
| structure = np.array([1, 2, 3]) | |||
| flattened = nest.flatten(structure) | |||
| self.assertEqual(len(flattened), 1) | |||
| def testFlatten_stringIsNotFlattened(self): | |||
| structure = "lots of letters" | |||
| flattened = nest.flatten(structure) | |||
| self.assertEqual(len(flattened), 1) | |||
| unflattened = nest.pack_sequence_as("goodbye", flattened) | |||
| self.assertEqual(structure, unflattened) | |||
| def testPackSequenceAs_notIterableError(self): | |||
| with self.assertRaisesRegexp(TypeError, | |||
| "flat_sequence must be a sequence"): | |||
| nest.pack_sequence_as("hi", "bye") | |||
| def testPackSequenceAs_wrongLengthsError(self): | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| "Structure had 2 elements, but flat_sequence had 3 elements."): | |||
| nest.pack_sequence_as(["hello", "world"], | |||
| ["and", "goodbye", "again"]) | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testIsSequence(self): | |||
| self.assertFalse(nest.is_sequence("1234")) | |||
| self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) | |||
| self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) | |||
| self.assertTrue(nest.is_sequence([])) | |||
| self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) | |||
| self.assertFalse(nest.is_sequence(set([1, 2]))) | |||
| ones = array_ops.ones([2, 3]) | |||
| self.assertFalse(nest.is_sequence(ones)) | |||
| self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) | |||
| self.assertFalse(nest.is_sequence(np.ones((4, 5)))) | |||
| @parameterized.parameters({"mapping_type": _CustomMapping}, | |||
| {"mapping_type": dict}) | |||
| def testFlattenDictItems(self, mapping_type): | |||
| dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))}) | |||
| flat = {4: "a", 5: "b", 6: "c", 8: "d"} | |||
| self.assertEqual(nest.flatten_dict_items(dictionary), flat) | |||
| with self.assertRaises(TypeError): | |||
| nest.flatten_dict_items(4) | |||
| bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))}) | |||
| with self.assertRaisesRegexp(ValueError, "not unique"): | |||
| nest.flatten_dict_items(bad_dictionary) | |||
| another_bad_dictionary = mapping_type({ | |||
| (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) | |||
| }) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): | |||
| nest.flatten_dict_items(another_bad_dictionary) | |||
| # pylint does not correctly recognize these as class names and | |||
| # suggests to use variable style under_score naming. | |||
| # pylint: disable=invalid-name | |||
| Named0ab = collections.namedtuple("named_0", ("a", "b")) | |||
| Named1ab = collections.namedtuple("named_1", ("a", "b")) | |||
| SameNameab = collections.namedtuple("same_name", ("a", "b")) | |||
| SameNameab2 = collections.namedtuple("same_name", ("a", "b")) | |||
| SameNamexy = collections.namedtuple("same_name", ("x", "y")) | |||
| SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) | |||
| SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) | |||
| NotSameName = collections.namedtuple("not_same_name", ("a", "b")) | |||
| # pylint: enable=invalid-name | |||
| class SameNamedType1(SameNameab): | |||
| pass | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testAssertSameStructure(self): | |||
| structure1 = (((1, 2), 3), 4, (5, 6)) | |||
| structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||
| structure_different_num_elements = ("spam", "eggs") | |||
| structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) | |||
| nest.assert_same_structure(structure1, structure2) | |||
| nest.assert_same_structure("abc", 1.0) | |||
| nest.assert_same_structure("abc", np.array([0, 1])) | |||
| nest.assert_same_structure("abc", constant_op.constant([0, 1])) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| ("The two structures don't have the same nested structure\\.\n\n" | |||
| "First structure:.*?\n\n" | |||
| "Second structure:.*\n\n" | |||
| "More specifically: Substructure " | |||
| r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' | |||
| 'substructure "type=str str=spam" is not\n' | |||
| "Entire first structure:\n" | |||
| r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" | |||
| "Entire second structure:\n" | |||
| r"\(\., \.\)")): | |||
| nest.assert_same_structure(structure1, structure_different_num_elements) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| ("The two structures don't have the same nested structure\\.\n\n" | |||
| "First structure:.*?\n\n" | |||
| "Second structure:.*\n\n" | |||
| r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||
| r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' | |||
| "is not")): | |||
| nest.assert_same_structure([0, 1], np.array([0, 1])) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| ("The two structures don't have the same nested structure\\.\n\n" | |||
| "First structure:.*?\n\n" | |||
| "Second structure:.*\n\n" | |||
| r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||
| 'is a sequence, while substructure "type=int str=0" ' | |||
| "is not")): | |||
| nest.assert_same_structure(0, [0, 1]) | |||
| self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| ("don't have the same nested structure\\.\n\n" | |||
| "First structure: .*?\n\nSecond structure: ")): | |||
| nest.assert_same_structure(structure1, structure_different_nesting) | |||
| self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), | |||
| NestTest.Named0ab("a", "b")) | |||
| nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||
| NestTest.Named0ab("a", "b")) | |||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||
| NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| ("don't have the same nested structure\\.\n\n" | |||
| "First structure: .*?\n\nSecond structure: ")): | |||
| nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||
| NestTest.Named0ab([3], 4)) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| ("don't have the same nested structure\\.\n\n" | |||
| "First structure: .*?\n\nSecond structure: ")): | |||
| nest.assert_same_structure([[3], 4], [3, [4]]) | |||
| structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||
| with self.assertRaisesRegexp(TypeError, | |||
| "don't have the same sequence type"): | |||
| nest.assert_same_structure(structure1, structure1_list) | |||
| nest.assert_same_structure(structure1, structure2, check_types=False) | |||
| nest.assert_same_structure(structure1, structure1_list, check_types=False) | |||
| with self.assertRaisesRegexp(ValueError, | |||
| "don't have the same set of keys"): | |||
| nest.assert_same_structure({"a": 1}, {"b": 1}) | |||
| nest.assert_same_structure(NestTest.SameNameab(0, 1), | |||
| NestTest.SameNameab2(2, 3)) | |||
| # This assertion is expected to pass: two namedtuples with the same | |||
| # name and field names are considered to be identical. | |||
| nest.assert_same_structure( | |||
| NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), | |||
| NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) | |||
| expected_message = "The two structures don't have the same.*" | |||
| with self.assertRaisesRegexp(ValueError, expected_message): | |||
| nest.assert_same_structure( | |||
| NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), | |||
| NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) | |||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||
| NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) | |||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||
| NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) | |||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||
| NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) | |||
| EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name | |||
| def testHeterogeneousComparison(self): | |||
| nest.assert_same_structure({"a": 4}, _CustomMapping(a=3)) | |||
| nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testMapStructure(self): | |||
| structure1 = (((1, 2), 3), 4, (5, 6)) | |||
| structure2 = (((7, 8), 9), 10, (11, 12)) | |||
| structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) | |||
| nest.assert_same_structure(structure1, structure1_plus1) | |||
| self.assertAllEqual( | |||
| [2, 3, 4, 5, 6, 7], | |||
| nest.flatten(structure1_plus1)) | |||
| structure1_plus_structure2 = nest.map_structure( | |||
| lambda x, y: x + y, structure1, structure2) | |||
| self.assertEqual( | |||
| (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), | |||
| structure1_plus_structure2) | |||
| self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) | |||
| self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) | |||
| # Empty structures | |||
| self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) | |||
| self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) | |||
| self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) | |||
| self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, | |||
| NestTest.EmptyNT())) | |||
| # This is checking actual equality of types, empty list != empty tuple | |||
| self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) | |||
| with self.assertRaisesRegexp(TypeError, "callable"): | |||
| nest.map_structure("bad", structure1_plus1) | |||
| with self.assertRaisesRegexp(ValueError, "at least one structure"): | |||
| nest.map_structure(lambda x: x) | |||
| with self.assertRaisesRegexp(ValueError, "same number of elements"): | |||
| nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) | |||
| with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
| nest.map_structure(lambda x, y: None, 3, (3,)) | |||
| with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||
| nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) | |||
| with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
| nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) | |||
| structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||
| with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||
| nest.map_structure(lambda x, y: None, structure1, structure1_list) | |||
| nest.map_structure(lambda x, y: None, structure1, structure1_list, | |||
| check_types=False) | |||
| with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
| nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), | |||
| check_types=False) | |||
| with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||
| nest.map_structure(lambda x: None, structure1, foo="a") | |||
| with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||
| nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") | |||
| ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testMapStructureWithStrings(self): | |||
| inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) | |||
| inp_b = NestTest.ABTuple(a=2, b=(1, 3)) | |||
| out = nest.map_structure(lambda string, repeats: string * repeats, | |||
| inp_a, | |||
| inp_b) | |||
| self.assertEqual("foofoo", out.a) | |||
| self.assertEqual("bar", out.b[0]) | |||
| self.assertEqual("bazbazbaz", out.b[1]) | |||
| nt = NestTest.ABTuple(a=("something", "something_else"), | |||
| b="yet another thing") | |||
| rev_nt = nest.map_structure(lambda x: x[::-1], nt) | |||
| # Check the output is the correct structure, and all strings are reversed. | |||
| nest.assert_same_structure(nt, rev_nt) | |||
| self.assertEqual(nt.a[0][::-1], rev_nt.a[0]) | |||
| self.assertEqual(nt.a[1][::-1], rev_nt.a[1]) | |||
| self.assertEqual(nt.b[::-1], rev_nt.b) | |||
| @test_util.run_deprecated_v1 | |||
| def testMapStructureOverPlaceholders(self): | |||
| inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||
| array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||
| inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||
| array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||
| output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) | |||
| nest.assert_same_structure(output, inp_a) | |||
| self.assertShapeEqual(np.zeros((3, 4)), output[0]) | |||
| self.assertShapeEqual(np.zeros((3, 7)), output[1]) | |||
| feed_dict = { | |||
| inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), | |||
| inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) | |||
| } | |||
| with self.cached_session() as sess: | |||
| output_np = sess.run(output, feed_dict=feed_dict) | |||
| self.assertAllClose(output_np[0], | |||
| feed_dict[inp_a][0] + feed_dict[inp_b][0]) | |||
| self.assertAllClose(output_np[1], | |||
| feed_dict[inp_a][1] + feed_dict[inp_b][1]) | |||
| def testAssertShallowStructure(self): | |||
| inp_ab = ["a", "b"] | |||
| inp_abc = ["a", "b", "c"] | |||
| expected_message = ( | |||
| "The two structures don't have the same sequence length. Input " | |||
| "structure has length 2, while shallow structure has length 3.") | |||
| with self.assertRaisesRegexp(ValueError, expected_message): | |||
| nest.assert_shallow_structure(inp_abc, inp_ab) | |||
| inp_ab1 = [(1, 1), (2, 2)] | |||
| inp_ab2 = [[1, 1], [2, 2]] | |||
| expected_message = ( | |||
| "The two structures don't have the same sequence type. Input structure " | |||
| "has type <(type|class) 'tuple'>, while shallow structure has type " | |||
| "<(type|class) 'list'>.") | |||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||
| nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||
| nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) | |||
| inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} | |||
| inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} | |||
| expected_message = ( | |||
| r"The two structures don't have the same keys. Input " | |||
| r"structure has keys \['c'\], while shallow structure has " | |||
| r"keys \['d'\].") | |||
| with self.assertRaisesRegexp(ValueError, expected_message): | |||
| nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||
| inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) | |||
| inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) | |||
| nest.assert_shallow_structure(inp_ab, inp_ba) | |||
| # This assertion is expected to pass: two namedtuples with the same | |||
| # name and field names are considered to be identical. | |||
| inp_shallow = NestTest.SameNameab(1, 2) | |||
| inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) | |||
| nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) | |||
| nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) | |||
| def testFlattenUpTo(self): | |||
| # Shallow tree ends at scalar. | |||
| input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||
| shallow_tree = [[True, True], [False, True]] | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) | |||
| self.assertEqual(flattened_shallow_tree, [True, True, False, True]) | |||
| # Shallow tree ends at string. | |||
| input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] | |||
| shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] | |||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| input_tree) | |||
| input_tree_flattened = nest.flatten(input_tree) | |||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) | |||
| self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) | |||
| # Make sure dicts are correctly flattened, yielding values, not keys. | |||
| input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} | |||
| shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} | |||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| input_tree) | |||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| [1, {"c": 2}, 3, (4, 5)]) | |||
| # Namedtuples. | |||
| ab_tuple = NestTest.ABTuple | |||
| input_tree = ab_tuple(a=[0, 1], b=2) | |||
| shallow_tree = ab_tuple(a=0, b=1) | |||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| input_tree) | |||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| [[0, 1], 2]) | |||
| # Nested dicts, OrderedDicts and namedtuples. | |||
| input_tree = collections.OrderedDict( | |||
| [("a", ab_tuple(a=[0, {"b": 1}], b=2)), | |||
| ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) | |||
| shallow_tree = input_tree | |||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| input_tree) | |||
| self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) | |||
| shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) | |||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| input_tree) | |||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| [ab_tuple(a=[0, {"b": 1}], b=2), | |||
| 3, | |||
| collections.OrderedDict([("f", 4)])]) | |||
| shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) | |||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| input_tree) | |||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| [ab_tuple(a=[0, {"b": 1}], b=2), | |||
| {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) | |||
| ## Shallow non-list edge-case. | |||
| # Using iterable elements. | |||
| input_tree = ["input_tree"] | |||
| shallow_tree = "shallow_tree" | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| input_tree = ["input_tree_0", "input_tree_1"] | |||
| shallow_tree = "shallow_tree" | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| # Using non-iterable elements. | |||
| input_tree = [0] | |||
| shallow_tree = 9 | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| input_tree = [0, 1] | |||
| shallow_tree = 9 | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| ## Both non-list edge-case. | |||
| # Using iterable elements. | |||
| input_tree = "input_tree" | |||
| shallow_tree = "shallow_tree" | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| # Using non-iterable elements. | |||
| input_tree = 0 | |||
| shallow_tree = 0 | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| ## Input non-list edge-case. | |||
| # Using iterable elements. | |||
| input_tree = "input_tree" | |||
| shallow_tree = ["shallow_tree"] | |||
| expected_message = ("If shallow structure is a sequence, input must also " | |||
| "be a sequence. Input has type: <(type|class) 'str'>.") | |||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| input_tree = "input_tree" | |||
| shallow_tree = ["shallow_tree_9", "shallow_tree_8"] | |||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| # Using non-iterable elements. | |||
| input_tree = 0 | |||
| shallow_tree = [9] | |||
| expected_message = ("If shallow structure is a sequence, input must also " | |||
| "be a sequence. Input has type: <(type|class) 'int'>.") | |||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| input_tree = 0 | |||
| shallow_tree = [9, 8] | |||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| def testMapStructureUpTo(self): | |||
| # Named tuples. | |||
| ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||
| op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||
| inp_val = ab_tuple(a=2, b=3) | |||
| inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) | |||
| out = nest.map_structure_up_to( | |||
| inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) | |||
| self.assertEqual(out.a, 6) | |||
| self.assertEqual(out.b, 15) | |||
| # Lists. | |||
| data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||
| name_list = ["evens", ["odds", "primes"]] | |||
| out = nest.map_structure_up_to( | |||
| name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), | |||
| name_list, data_list) | |||
| self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) | |||
| # Dicts. | |||
| inp_val = dict(a=2, b=3) | |||
| inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) | |||
| out = nest.map_structure_up_to( | |||
| inp_val, | |||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| self.assertEqual(out["a"], 6) | |||
| self.assertEqual(out["b"], 15) | |||
| # Non-equal dicts. | |||
| inp_val = dict(a=2, b=3) | |||
| inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) | |||
| with self.assertRaisesRegexp(ValueError, "same keys"): | |||
| nest.map_structure_up_to( | |||
| inp_val, | |||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| # Dict+custom mapping. | |||
| inp_val = dict(a=2, b=3) | |||
| inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) | |||
| out = nest.map_structure_up_to( | |||
| inp_val, | |||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| self.assertEqual(out["a"], 6) | |||
| self.assertEqual(out["b"], 15) | |||
| # Non-equal dict/mapping. | |||
| inp_val = dict(a=2, b=3) | |||
| inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) | |||
| with self.assertRaisesRegexp(ValueError, "same keys"): | |||
| nest.map_structure_up_to( | |||
| inp_val, | |||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| def testGetTraverseShallowStructure(self): | |||
| scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] | |||
| scalar_traverse_r = nest.get_traverse_shallow_structure( | |||
| lambda s: not isinstance(s, tuple), | |||
| scalar_traverse_input) | |||
| self.assertEqual(scalar_traverse_r, | |||
| [True, True, False, [True, True], {"a": False}, []]) | |||
| nest.assert_shallow_structure(scalar_traverse_r, | |||
| scalar_traverse_input) | |||
| structure_traverse_input = [(1, [2]), ([1], 2)] | |||
| structure_traverse_r = nest.get_traverse_shallow_structure( | |||
| lambda s: (True, False) if isinstance(s, tuple) else True, | |||
| structure_traverse_input) | |||
| self.assertEqual(structure_traverse_r, | |||
| [(True, False), ([True], False)]) | |||
| nest.assert_shallow_structure(structure_traverse_r, | |||
| structure_traverse_input) | |||
| with self.assertRaisesRegexp(TypeError, "returned structure"): | |||
| nest.get_traverse_shallow_structure(lambda _: [True], 0) | |||
| with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): | |||
| nest.get_traverse_shallow_structure(lambda _: 1, [1]) | |||
| with self.assertRaisesRegexp( | |||
| TypeError, "didn't return a depth=1 structure of bools"): | |||
| nest.get_traverse_shallow_structure(lambda _: [1], [1]) | |||
| def testYieldFlatStringPaths(self): | |||
| for inputs_expected in ({"inputs": [], "expected": []}, | |||
| {"inputs": 3, "expected": [()]}, | |||
| {"inputs": [3], "expected": [(0,)]}, | |||
| {"inputs": {"a": 3}, "expected": [("a",)]}, | |||
| {"inputs": {"a": {"b": 4}}, | |||
| "expected": [("a", "b")]}, | |||
| {"inputs": [{"a": 2}], "expected": [(0, "a")]}, | |||
| {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, | |||
| {"inputs": [{"a": [(23, 42)]}], | |||
| "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, | |||
| {"inputs": [{"a": ([23], 42)}], | |||
| "expected": [(0, "a", 0, 0), (0, "a", 1)]}, | |||
| {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, | |||
| "expected": [("a", "a"), ("c", 0, 0, 0)]}, | |||
| {"inputs": {"0": [{"1": 23}]}, | |||
| "expected": [("0", 0, "1")]}): | |||
| inputs = inputs_expected["inputs"] | |||
| expected = inputs_expected["expected"] | |||
| self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) | |||
| def testFlattenWithStringPaths(self): | |||
| for inputs_expected in ( | |||
| {"inputs": [], "expected": []}, | |||
| {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, | |||
| {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): | |||
| inputs = inputs_expected["inputs"] | |||
| expected = inputs_expected["expected"] | |||
| self.assertEqual( | |||
| nest.flatten_with_joined_string_paths(inputs, separator="/"), | |||
| expected) | |||
| # Need a separate test for namedtuple as we can't declare tuple definitions | |||
| # in the @parameterized arguments. | |||
| def testFlattenNamedTuple(self): | |||
| # pylint: disable=invalid-name | |||
| Foo = collections.namedtuple("Foo", ["a", "b"]) | |||
| Bar = collections.namedtuple("Bar", ["c", "d"]) | |||
| # pylint: enable=invalid-name | |||
| test_cases = [ | |||
| (Foo(a=3, b=Bar(c=23, d=42)), | |||
| [("a", 3), ("b/c", 23), ("b/d", 42)]), | |||
| (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")), | |||
| [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), | |||
| (Bar(c=42, d=43), | |||
| [("c", 42), ("d", 43)]), | |||
| (Bar(c=[42], d=43), | |||
| [("c/0", 42), ("d", 43)]), | |||
| ] | |||
| for inputs, expected in test_cases: | |||
| self.assertEqual( | |||
| list(nest.flatten_with_joined_string_paths(inputs)), expected) | |||
| @parameterized.named_parameters( | |||
| ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), | |||
| ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, | |||
| {"a": ("a", 4), "b": ("b", 6)}), | |||
| ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), | |||
| ("nested", | |||
| {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, | |||
| {"a": [("a/0", 10), ("a/1", 12)], | |||
| "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) | |||
| def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): | |||
| def format_sum(path, *values): | |||
| return (path, sum(values)) | |||
| result = nest.map_structure_with_paths(format_sum, s1, s2, | |||
| check_types=check_types) | |||
| self.assertEqual(expected, result) | |||
| @parameterized.named_parameters( | |||
| ("tuples", (1, 2), (3, 4, 5), ValueError), | |||
| ("dicts", {"a": 1}, {"b": 2}, ValueError), | |||
| ("mixed", (1, 2), [3, 4], TypeError), | |||
| ("nested", | |||
| {"a": [2, 3], "b": [1, 3]}, | |||
| {"b": [5, 6, 7], "a": [8, 9]}, | |||
| ValueError | |||
| )) | |||
| def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): | |||
| with self.assertRaises(error_type): | |||
| nest.map_structure_with_paths(lambda path, *s: 0, s1, s2) | |||
| class NestBenchmark(test.Benchmark): | |||
| def run_and_report(self, s1, s2, name): | |||
| burn_iter, test_iter = 100, 30000 | |||
| for _ in xrange(burn_iter): | |||
| nest.assert_same_structure(s1, s2) | |||
| t0 = time.time() | |||
| for _ in xrange(test_iter): | |||
| nest.assert_same_structure(s1, s2) | |||
| t1 = time.time() | |||
| self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, | |||
| name=name) | |||
| def benchmark_assert_structure(self): | |||
| s1 = (((1, 2), 3), 4, (5, 6)) | |||
| s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||
| self.run_and_report(s1, s2, "assert_same_structure_6_elem") | |||
| s1 = (((1, 2), 3), 4, (5, 6)) * 10 | |||
| s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 | |||
| self.run_and_report(s1, s2, "assert_same_structure_60_elem") | |||
| if __name__ == "__main__": | |||
| test.main() | |||
| @@ -1,249 +0,0 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System.Linq; | |||
| using Tensorflow; | |||
| using Tensorflow.UnitTest; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.ops_test | |||
| { | |||
| /// <summary> | |||
| /// excerpt of tensorflow/python/framework/ops_test.py | |||
| /// </summary> | |||
| [TestClass] | |||
| public class ControlDependenciesTest : GraphModeTestBase | |||
| { | |||
| [TestMethod] | |||
| public void TestBasic() | |||
| { | |||
| var g = tf.Graph().as_default(); | |||
| Tensor a = null, b = null, c = null, d = null, e = null; | |||
| a = constant_op.constant(1.0); | |||
| b = constant_op.constant(1.0); | |||
| tf_with(g.control_dependencies(new[] { a }), x => | |||
| { | |||
| c = constant_op.constant(1.0); | |||
| d = array_ops.identity(b); | |||
| e = array_ops.identity(c); | |||
| }); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op })); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(d.op.control_inputs, new[] { a.op })); | |||
| // e should be dominated by c. | |||
| Assert.AreEqual(0, e.op.control_inputs.Length); | |||
| } | |||
| [Ignore("How to port the ConvertibleObj?")] | |||
| [TestMethod] | |||
| public void TestBasicWithConversion() | |||
| { | |||
| var g = tf.Graph().as_default(); | |||
| // Note: _apply_op can be replaced by g.create_op | |||
| var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT }); | |||
| // TODO: ConvertibleObj, see original source below | |||
| /* | |||
| def testBasicWithConversion(self): | |||
| g = ops.Graph() | |||
| a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||
| class ConvertibleObj(object): | |||
| def _as_graph_element(self): | |||
| return a | |||
| with g.control_dependencies([ConvertibleObj()]): | |||
| c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||
| self.assertEqual(c.op.control_inputs, [a.op]) | |||
| */ | |||
| } | |||
| [TestMethod] | |||
| public void TestNested() | |||
| { | |||
| var g = tf.Graph().as_default(); | |||
| var a_1 = constant_op.constant(1.0); | |||
| var a_2 = constant_op.constant(3.0); | |||
| var a_3 = constant_op.constant(4.0); | |||
| var a_4 = constant_op.constant(5.0); | |||
| Tensor b_1 = null, b_2 = null; | |||
| tf_with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl => | |||
| { | |||
| b_1 = constant_op.constant(6.0); | |||
| }); | |||
| tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 => | |||
| { | |||
| tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 => | |||
| { | |||
| tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 => | |||
| { | |||
| tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 => | |||
| { | |||
| b_2 = constant_op.constant(7.0); | |||
| }); | |||
| }); | |||
| }); | |||
| }); | |||
| //var z=tf.add(a_1, tf.multiply(b_2, b_1)); | |||
| //with(g.control_dependencies(new[] {z}), ctrl => | |||
| //{ | |||
| // var z1 = tf.add(a_3, tf.multiply(a_4, a_2)); | |||
| //}); | |||
| //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||
| assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); | |||
| assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); | |||
| } | |||
| [TestMethod] | |||
| public void TestClear() | |||
| { | |||
| var g = tf.Graph().as_default(); | |||
| var a_1 = constant_op.constant(1.0); | |||
| var a_2 = constant_op.constant(3.0); | |||
| var a_3 = constant_op.constant(4.0); | |||
| var a_4 = constant_op.constant(5.0); | |||
| Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null; | |||
| tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 => | |||
| { | |||
| tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 => | |||
| { | |||
| tf_with(g.control_dependencies(null), ctrl3 => | |||
| { | |||
| tf_with(g.control_dependencies(new[] { a_3 }), ctrl4 => | |||
| { | |||
| tf_with(g.control_dependencies(new[] { a_4 }), ctrl5 => | |||
| { | |||
| // deps [a_3, a_4] | |||
| b_3_4 = constant_op.constant(7.0); | |||
| }); | |||
| // deps = [a_3] | |||
| b_3 = constant_op.constant(8.0); | |||
| }); | |||
| // deps back to None | |||
| b_none = constant_op.constant(9.0); | |||
| }); | |||
| // deps back to [a_1, a_2] | |||
| b_1_2 = constant_op.constant(10.0); | |||
| }); | |||
| // deps back to [a_1] | |||
| b_1 = constant_op.constant(11.0); | |||
| tf_with(g.control_dependencies(null), ctrl6 => | |||
| { | |||
| // deps are None again | |||
| b_none2 = constant_op.constant(12.0); | |||
| }); | |||
| }); | |||
| // Note assertItemsEqual(given, expected), expected and given parameters should be swapped below | |||
| assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs); | |||
| assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs); | |||
| assertItemsEqual(new object[0], b_none.op.control_inputs); | |||
| assertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs); | |||
| assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs); | |||
| assertItemsEqual(new object[0], b_none2.op.control_inputs); | |||
| } | |||
| [TestMethod] | |||
| public void TestComplex() | |||
| { | |||
| var g = tf.Graph().as_default(); | |||
| // Usage pattern: | |||
| // * Nodes a_i are constants defined at the outermost scope, and are used | |||
| // as control inputs for the ith nested scope. | |||
| // * Nodes b_i are defined as Mul(a_3, a_4) at each scope. | |||
| // * Nodes c_i are defined as Mul(a_1, b_1) at each scope. | |||
| // * Nodes d_i are defined as Mul(b_i, c_i) at each scope. | |||
| // * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. | |||
| var a_1 = constant_op.constant(1.0); | |||
| var a_2 = constant_op.constant(2.0); | |||
| var a_3 = constant_op.constant(3.0); | |||
| var a_4 = constant_op.constant(4.0); | |||
| Operation b_1 = null, b_2 = null, b_3 = null, b_4 = null; | |||
| Operation c_1 = null, c_2 = null, c_3 = null, c_4 = null; | |||
| Operation d_1 = null, d_2 = null, d_3 = null, d_4 = null; | |||
| Operation e_1 = null, e_2 = null, e_3 = null, e_4 = null; | |||
| tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 => | |||
| { | |||
| b_1 = tf.multiply(a_3, a_4); | |||
| c_1 = tf.multiply(a_1, b_1.output); | |||
| d_1 = tf.multiply(b_1.output, c_1.output); | |||
| e_1 = constant_op.constant(5.0); | |||
| tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 => | |||
| { | |||
| b_2 = tf.multiply(a_3, a_4); | |||
| c_2 = tf.multiply(a_1, b_1.output); | |||
| d_2 = tf.multiply(b_2.output, c_2.output); | |||
| e_2 = tf.multiply(e_1.output, e_1.output); | |||
| tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 => | |||
| { | |||
| b_3 = tf.multiply(a_3, a_4); | |||
| c_3 = tf.multiply(a_1, b_1.output); | |||
| d_3 = tf.multiply(b_3.output, c_3.output); | |||
| e_3 = tf.multiply(e_2.output, e_2.output); | |||
| tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 => | |||
| { | |||
| b_4 = tf.multiply(a_3, a_4); | |||
| c_4 = tf.multiply(a_1, b_1.output); | |||
| d_4 = tf.multiply(b_4.output, c_4.output); | |||
| e_4 = tf.multiply(e_3.output, e_3.output); | |||
| }); | |||
| }); | |||
| }); | |||
| }); | |||
| // Note assertItemsEqual(given, expected), expected and given parameters should be swapped below | |||
| assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs); | |||
| assertItemsEqual(new[] { a_1.op, a_2.op }, b_2.op.control_inputs); | |||
| assertItemsEqual(new[] { a_1.op, a_2.op }, b_3.op.control_inputs); | |||
| assertItemsEqual(new[] { a_1.op, a_2.op }, b_4.op.control_inputs); | |||
| assertItemsEqual(new object[0], c_1.op.control_inputs); | |||
| assertItemsEqual(new[] { a_2.op }, c_2.op.control_inputs); | |||
| assertItemsEqual(new[] { a_2.op, a_3.op }, c_3.op.control_inputs); | |||
| assertItemsEqual(new[] { a_2.op, a_3.op, a_4.op }, c_4.op.control_inputs); | |||
| assertItemsEqual(new object[0], d_1.op.control_inputs); | |||
| assertItemsEqual(new object[0], d_2.op.control_inputs); | |||
| assertItemsEqual(new object[0], d_3.op.control_inputs); | |||
| assertItemsEqual(new object[0], d_4.op.control_inputs); | |||
| assertItemsEqual(new[] { a_1.op }, e_1.op.control_inputs); | |||
| assertItemsEqual(new[] { a_2.op }, e_2.op.control_inputs); | |||
| assertItemsEqual(new[] { a_3.op }, e_3.op.control_inputs); | |||
| assertItemsEqual(new[] { a_4.op }, e_4.op.control_inputs); | |||
| } | |||
| [Ignore("Don't know how to create an operation with two outputs")] | |||
| [TestMethod] | |||
| public void TestRepeatedDependency() | |||
| { | |||
| /* | |||
| def testRepeatedDependency(self): | |||
| g = ops.Graph() | |||
| a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) | |||
| a_0, a_1 = a.outputs | |||
| with g.control_dependencies([a_0]): | |||
| b = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||
| with g.control_dependencies([a_1]): | |||
| c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||
| self.assertEqual(b.op.control_inputs, [a]) | |||
| self.assertEqual(c.op.control_inputs, [a]) | |||
| */ | |||
| } | |||
| [TestMethod] | |||
| public void TestNoControlDependencyWithDataDependency() | |||
| { | |||
| var g = tf.Graph().as_default(); | |||
| Operation b = null; | |||
| var a = constant_op.constant(100.0); | |||
| tf_with(g.control_dependencies(new[] { a }), ctrl1 => | |||
| { | |||
| b = array_ops.identity(a); | |||
| }); | |||
| Assert.AreEqual(0, b.op.control_inputs.Length); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,222 +0,0 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.UnitTest; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.ops_test | |||
| { | |||
| /// <summary> | |||
| /// excerpt of tensorflow/python/framework/ops_test.py | |||
| /// # These cases test the private Graph._create_op_from_tf_operation | |||
| /// # method. Arguably we should only test the public APIs that depend on this | |||
| /// # method. However, this logic is complex and tricky, and it can be difficult to | |||
| /// # ascertain if we have adequate coverage (e.g. a graph may run successfully if | |||
| /// # the control flow context isn't set properly, but a more complicated use case | |||
| /// # that might not be obvious to test will fail). Thus we instead explicitly test | |||
| /// # the low-level behavior. | |||
| /// </summary> | |||
| [Ignore] | |||
| [TestClass] | |||
| public class CreateOpFromTfOperationTest : GraphModeTestBase | |||
| { | |||
| [TestMethod] | |||
| public void TestShape() | |||
| { | |||
| using (var g = tf.Graph().as_default()) | |||
| { | |||
| var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); | |||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); | |||
| var op = g._create_op_from_tf_operation(c_op); | |||
| Assert.AreEqual("myop", op.name); | |||
| Assert.AreEqual("Identity", op.type); | |||
| Assert.AreEqual(1, len(op.outputs)); | |||
| assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); | |||
| } | |||
| } | |||
| [TestMethod] | |||
| public void TestUniqueName() | |||
| { | |||
| var graph = tf.Graph().as_default(); | |||
| //var (c_op,op_desc) = ops._create_c_op(g, ops._NodeDef("Const", "myop"), new Tensor[0], new Operation[0]); | |||
| //var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]); | |||
| //var op = g._create_op_from_tf_operation(c_op); | |||
| //var op2 = g._create_op_from_tf_operation(c_op2); | |||
| var op = constant_op.constant(0, name: "myop").op; | |||
| var op2 = constant_op.constant(0, name: "myop_1").op; | |||
| // Create ops with same names as op1 and op2. We expect the new names to be | |||
| // uniquified. | |||
| var op3 = constant_op.constant(0, name: "myop").op; | |||
| var op4 = constant_op.constant(0, name: "myop_1").op; | |||
| self.assertEqual(op.name, "myop"); | |||
| self.assertEqual(op2.name, "myop_1"); | |||
| self.assertEqual(op3.name, "myop_2"); | |||
| self.assertEqual(op4.name, "myop_1_1"); | |||
| } | |||
| [Ignore("need tesnroflow expose UpdateEdge API")] | |||
| [TestMethod] | |||
| public void TestCond() | |||
| { | |||
| var g = tf.Graph().as_default(); | |||
| var x = constant_op.constant(10); | |||
| var true_fn = new Func<Tensor>(() => | |||
| { | |||
| var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); | |||
| var new_ops = g._add_new_tf_operations(); | |||
| self.assertEqual(len(new_ops), 1); | |||
| return x; | |||
| }); | |||
| control_flow_ops.cond(x < 10, true_fn, () => x); | |||
| var op = g.get_operation_by_name("cond/myop"); | |||
| //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta.txt", as_text:true); | |||
| //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||
| self.assertIsNotNone(op); | |||
| self.assertEqual(op.name, "cond/myop"); | |||
| self.assertEqual(op.type, "Identity"); | |||
| //self.assertEqual(op.outputs, new object[0]); | |||
| var op_input = op.inputs[0].op; | |||
| self.assertEqual(op_input.type, "Switch"); | |||
| self.assertEqual(op_input.inputs[0].name, x.name); | |||
| self.assertEqual(op.graph, g); | |||
| self.assertIsNotNone(op._get_control_flow_context()); | |||
| var cond_text = op._get_control_flow_context() as ControlFlowContext; | |||
| self.assertEqual(cond_text.Name, "cond/cond_text"); | |||
| } | |||
| [Ignore("Todo: Port")] | |||
| [TestMethod] | |||
| public void TestWhileLoop() | |||
| { | |||
| var graph = tf.Graph().as_default(); | |||
| Operation x = null; | |||
| x = constant_op.constant(42); | |||
| var body = new Func<int, int>(i => | |||
| { | |||
| ops._create_c_op(ops.get_default_graph(), ops._NodeDef("Identity", "myloop/myop"), new[] { x.output }, | |||
| new Operation[0]); | |||
| var new_ops = graph._add_new_tf_operations(); | |||
| self.assertEqual(len(new_ops), 1); | |||
| return i; | |||
| }); | |||
| // TODO: port control_flow_ops.while_loop | |||
| //control_flow_ops.while_loop( i => i < 10, body, new int[]{0}, name = "myloop"); | |||
| var op = graph.get_operation_by_name("myloop/myop"); | |||
| self.assertIsNotNone(op); | |||
| self.assertEqual(op.name, "myloop/myop"); | |||
| self.assertEqual(op.type, "Identity"); | |||
| self.assertEqual(op.outputs.Length, 0); | |||
| var op_input = op.inputs[0].op; | |||
| self.assertEqual(op_input.type, "Enter"); | |||
| self.assertItemsEqual(op_input.inputs.OfType<Operation>().ToArray(), new[] { x }); | |||
| self.assertEqual(op.graph, graph); | |||
| self.assertIsNotNone(op._get_control_flow_context()); | |||
| self.assertEqual(((ControlFlowContext)op._get_control_flow_context()).Name, "myloop/while_context"); | |||
| /* | |||
| @test_util.run_v1_only("b/120545219") | |||
| def testWhileLoop(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| x = test_ops.int_output() | |||
| def body(i): | |||
| ops._create_c_op(ops.get_default_graph(), | |||
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
| new_ops = g._add_new_tf_operations() | |||
| self.assertEqual(len(new_ops), 1) | |||
| return i | |||
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
| op = g.get_operation_by_name("myloop/myop") | |||
| self.assertIsNotNone(op) | |||
| self.assertEqual(op.name, "myloop/myop") | |||
| self.assertEqual(op.type, "IntInput") | |||
| self.assertEqual(op.outputs, []) | |||
| op_input = op.inputs[0].op | |||
| self.assertEqual(op_input.type, "Enter") | |||
| self.assertEqual(list(op_input.inputs), [x]) | |||
| self.assertEqual(op.graph, g) | |||
| # pylint: disable=protected-access | |||
| self.assertIsNotNone(op._get_control_flow_context()) | |||
| self.assertEqual(op._get_control_flow_context().name, | |||
| "myloop/while_context") | |||
| # pylint: enable=protected-access | |||
| */ | |||
| } | |||
| [Ignore("Todo: Port")] | |||
| [TestMethod] | |||
| public void TestWhileLoopWithInternalControlDep() | |||
| { | |||
| /* | |||
| @test_util.run_v1_only("b/120545219") | |||
| def testWhileLoopWithInternalControlDep(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| x = test_ops.int_output() | |||
| def body(i): | |||
| c = constant_op.constant(1.0, name="c") | |||
| ops._create_c_op(ops.get_default_graph(), | |||
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
| with ops.control_dependencies([c]): | |||
| new_ops = g._add_new_tf_operations() | |||
| self.assertEqual(len(new_ops), 1) | |||
| return i | |||
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
| op = g.get_operation_by_name("myloop/myop") | |||
| self.assertIsNotNone(op) | |||
| c = g.get_operation_by_name("myloop/c") | |||
| self.assertIsNotNone(c) | |||
| # Internal control dep is preserved | |||
| self.assertEqual(op.control_inputs, [c]) | |||
| */ | |||
| } | |||
| [Ignore("Todo: Port")] | |||
| [TestMethod] | |||
| public void TestWhileLoopWithExternalControlDep() | |||
| { | |||
| /* | |||
| @test_util.run_v1_only("b/120545219") | |||
| def testWhileLoopWithExternalControlDep(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| x = test_ops.int_output() | |||
| c = constant_op.constant(1.0) | |||
| def body(i): | |||
| ops._create_c_op(ops.get_default_graph(), | |||
| ops._NodeDef("IntInput", "myloop/myop"), [x], []) | |||
| with ops.control_dependencies([c]): | |||
| new_ops = g._add_new_tf_operations() | |||
| self.assertEqual(len(new_ops), 1) | |||
| return i | |||
| control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") | |||
| op = g.get_operation_by_name("myloop/myop") | |||
| self.assertIsNotNone(op) | |||
| # External control dep is removed and replaced with internal control dep | |||
| self.assertNotEqual(op.control_inputs[0], c.op) | |||
| self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) | |||
| */ | |||
| } | |||
| } | |||
| } | |||
| @@ -1,196 +0,0 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using Tensorflow.UnitTest; | |||
| namespace TensorFlowNET.UnitTest.ops_test | |||
| { | |||
| /// <summary> | |||
| /// excerpt of tensorflow/python/framework/ops_test.py | |||
| /// </summary> | |||
| [TestClass] | |||
| public class GraphTest : GraphModeTestBase | |||
| { | |||
| [TestInitialize] | |||
| public void SetUp() | |||
| { | |||
| ops.reset_default_graph(); | |||
| } | |||
| [TestCleanup] | |||
| public void TearDown() | |||
| { | |||
| ops.reset_default_graph(); | |||
| } | |||
| private void _AssertDefault(Graph expected) | |||
| { | |||
| Assert.AreSame(ops.get_default_graph(), expected); | |||
| } | |||
| [Ignore("Todo: Port")] | |||
| [TestMethod] | |||
| public void testResetDefaultGraphNesting() | |||
| { | |||
| /* | |||
| def testResetDefaultGraphNesting(self): | |||
| g0 = ops.Graph() | |||
| with self.assertRaises(AssertionError): | |||
| with g0.as_default(): | |||
| ops.reset_default_graph() | |||
| */ | |||
| } | |||
| [Ignore("Todo: Port")] | |||
| [TestMethod] | |||
| public void testGraphContextManagerCancelsEager() | |||
| { | |||
| /* | |||
| def testGraphContextManagerCancelsEager(self): | |||
| with context.eager_mode(): | |||
| with ops.Graph().as_default(): | |||
| self.assertFalse(context.executing_eagerly()) | |||
| */ | |||
| } | |||
| [Ignore("Todo: Port")] | |||
| [TestMethod] | |||
| public void testGraphContextManager() | |||
| { | |||
| /* | |||
| def testGraphContextManager(self): | |||
| g0 = ops.Graph() | |||
| with g0.as_default() as g1: | |||
| self.assertIs(g0, g1) | |||
| */ | |||
| } | |||
| [Ignore("Todo: Port")] | |||
| [TestMethod] | |||
| public void testDefaultGraph() | |||
| { | |||
| /* | |||
| def testDefaultGraph(self): | |||
| orig = ops.get_default_graph() | |||
| self._AssertDefault(orig) | |||
| g0 = ops.Graph() | |||
| self._AssertDefault(orig) | |||
| context_manager_0 = g0.as_default() | |||
| self._AssertDefault(orig) | |||
| with context_manager_0 as g0: | |||
| self._AssertDefault(g0) | |||
| with ops.Graph().as_default() as g1: | |||
| self._AssertDefault(g1) | |||
| self._AssertDefault(g0) | |||
| self._AssertDefault(orig) | |||
| */ | |||
| } | |||
| [Ignore("Todo: Port")] | |||
| [TestMethod] | |||
| public void testPreventFeeding() | |||
| { | |||
| /* | |||
| def testPreventFeeding(self): | |||
| g = ops.Graph() | |||
| a = constant_op.constant(2.0) | |||
| self.assertTrue(g.is_feedable(a)) | |||
| g.prevent_feeding(a) | |||
| self.assertFalse(g.is_feedable(a)) | |||
| */ | |||
| } | |||
| [Ignore("Todo: Port")] | |||
| [TestMethod] | |||
| public void testAsGraphElementConversions() | |||
| { | |||
| /* | |||
| def testAsGraphElementConversions(self): | |||
| class ConvertibleObj(object): | |||
| def _as_graph_element(self): | |||
| return "FloatOutput:0" | |||
| class NonConvertibleObj(object): | |||
| pass | |||
| g = ops.Graph() | |||
| a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||
| self.assertEqual(a, g.as_graph_element(ConvertibleObj())) | |||
| with self.assertRaises(TypeError): | |||
| g.as_graph_element(NonConvertibleObj()) | |||
| */ | |||
| } | |||
| [Ignore("Todo: Port")] | |||
| [TestMethod] | |||
| public void testGarbageCollected() | |||
| { | |||
| /* | |||
| # Regression test against creating custom __del__ functions in classes | |||
| # involved in cyclic references, e.g. Graph and Operation. (Python won't gc | |||
| # cycles that require calling a __del__ method, because the __del__ method can | |||
| # theoretically increase the object's refcount to "save" it from gc, and any | |||
| # already-deleted objects in the cycle would have be to restored.) | |||
| def testGarbageCollected(self): | |||
| # Create a graph we can delete and a weak reference to monitor if it's gc'd | |||
| g = ops.Graph() | |||
| g_ref = weakref.ref(g) | |||
| # Create some ops | |||
| with g.as_default(): | |||
| a = constant_op.constant(2.0) | |||
| b = constant_op.constant(3.0) | |||
| c = math_ops.add(a, b) | |||
| # Create a session we can delete | |||
| with session.Session(graph=g) as sess: | |||
| self.evaluate(c) | |||
| # Delete all references and trigger gc | |||
| del g | |||
| del a | |||
| del b | |||
| del c | |||
| del sess | |||
| gc.collect() | |||
| self.assertIsNone(g_ref()) | |||
| */ | |||
| } | |||
| [Ignore("Todo: Port")] | |||
| [TestMethod] | |||
| public void testRunnableAfterInvalidShape() | |||
| { | |||
| /* | |||
| def testRunnableAfterInvalidShape(self): | |||
| with ops.Graph().as_default(): | |||
| with self.assertRaises(ValueError): | |||
| math_ops.add([1, 2], [1, 2, 3]) | |||
| a = constant_op.constant(1) | |||
| with session.Session() as sess: | |||
| self.evaluate(a) | |||
| */ | |||
| } | |||
| [Ignore("Todo: Port")] | |||
| [TestMethod] | |||
| public void testRunnableAfterInvalidShapeWithKernelLabelMap() | |||
| { | |||
| /* | |||
| def testRunnableAfterInvalidShapeWithKernelLabelMap(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): | |||
| with self.assertRaises(ValueError): | |||
| test_ops.kernel_label_required(1) | |||
| a = constant_op.constant(1) | |||
| with session.Session() as sess: | |||
| self.evaluate(a) | |||
| */ | |||
| } | |||
| } | |||
| } | |||
| @@ -1,26 +0,0 @@ | |||
| | |||
| import tensorflow as tf | |||
| # Create some variables. | |||
| v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer) | |||
| v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer) | |||
| inc_v1 = v1.assign(v1+1) | |||
| dec_v2 = v2.assign(v2-1) | |||
| # Add an op to initialize the variables. | |||
| init_op = tf.global_variables_initializer() | |||
| # Add ops to save and restore all the variables. | |||
| saver = tf.train.Saver() | |||
| # Later, launch the model, initialize the variables, do some work, and save the | |||
| # variables to disk. | |||
| with tf.Session() as sess: | |||
| sess.run(init_op) | |||
| # Do some work with the model. | |||
| inc_v1.op.run() | |||
| dec_v2.op.run() | |||
| # Save the variables to disk. | |||
| save_path = saver.save(sess, "/tmp/model.ckpt") | |||
| print("Model saved in path: %s" % save_path) | |||