Browse Source

Adjust unit test project.

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
6fe6057ff4
66 changed files with 553 additions and 7089 deletions
  1. +32
    -6
      TensorFlow.NET.sln
  2. +1
    -0
      src/TensorFlowNET.Console/Tensorflow.Console.csproj
  3. +2
    -9
      src/TensorFlowNET.Core/APIs/tf.graph.cs
  4. +1
    -13
      src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
  5. +10
    -2
      src/TensorFlowNET.Core/Contexts/Context.cs
  6. +10
    -0
      src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs
  7. +4
    -2
      src/TensorFlowNET.Core/Data/MapDataset.cs
  8. +16
    -5
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  9. +0
    -1
      src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
  10. +2
    -1
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  11. +4
    -2
      src/TensorFlowNET.Core/Graphs/AutoGraph.cs
  12. +4
    -1
      src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs
  13. +18
    -38
      src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs
  14. +3
    -4
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  15. +7
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  16. +1
    -1
      src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs
  17. +5
    -2
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  18. +20
    -0
      src/TensorFlowNET.Core/ops.cs
  19. +9
    -9
      src/TensorFlowNET.Core/ops.threading.cs
  20. +3
    -3
      src/TensorFlowNET.Keras/BackendImpl.cs
  21. +1
    -1
      src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs
  22. +2
    -1
      src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
  23. +1
    -0
      src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj
  24. +5
    -0
      src/TensorFlowNET.Text/Tensorflow.Text.csproj
  25. +1
    -2
      test/TensorFlowNET.Native.UnitTest/CApiAttributesTestcs.cs
  26. +1
    -2
      test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs
  27. +2
    -3
      test/TensorFlowNET.Native.UnitTest/CApiFunctionTest.cs
  28. +1
    -3
      test/TensorFlowNET.Native.UnitTest/CApiGradientsTest.cs
  29. +2
    -4
      test/TensorFlowNET.Native.UnitTest/CApiTest.cs
  30. +1
    -2
      test/TensorFlowNET.Native.UnitTest/CSession.cs
  31. +1
    -2
      test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Context.cs
  32. +1
    -2
      test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Execute_MatMul_CPU.cs
  33. +1
    -2
      test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs
  34. +1
    -2
      test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs
  35. +1
    -2
      test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandle.cs
  36. +1
    -2
      test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs
  37. +1
    -2
      test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Variables.cs
  38. +1
    -2
      test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.cs
  39. +2
    -3
      test/TensorFlowNET.Native.UnitTest/Eager/GradientEagerTest.cs
  40. +1
    -2
      test/TensorFlowNET.Native.UnitTest/GraphBuildTest.cs
  41. +1
    -3
      test/TensorFlowNET.Native.UnitTest/GraphTest.cs
  42. +74
    -0
      test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs
  43. +36
    -0
      test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj
  44. +204
    -0
      test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs
  45. +1
    -3
      test/TensorFlowNET.Native.UnitTest/c_test_util.cs
  46. +3
    -70
      test/TensorFlowNET.UnitTest/Basics/SessionTest.cs
  47. +3
    -220
      test/TensorFlowNET.UnitTest/Basics/TensorTest.cs
  48. +30
    -0
      test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
  49. +2
    -2
      test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs
  50. +2
    -0
      test/TensorFlowNET.UnitTest/GraphModeTestBase.cs
  51. +1
    -1
      test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs
  52. +1
    -1
      test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs
  53. +1
    -1
      test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs
  54. +1
    -1
      test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs
  55. +0
    -35
      test/TensorFlowNET.UnitTest/ManagedAPI/TFNetApiTest.cs
  56. +0
    -86
      test/TensorFlowNET.UnitTest/ManagedAPI/ZeroFractionTest.cs
  57. +1
    -0
      test/TensorFlowNET.UnitTest/MultithreadingTests.cs
  58. +12
    -3
      test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs
  59. +0
    -1059
      test/TensorFlowNET.UnitTest/control_flow_ops_test/control_flow_ops_test.py
  60. +0
    -875
      test/TensorFlowNET.UnitTest/nest_test/NestTest.cs
  61. +0
    -883
      test/TensorFlowNET.UnitTest/nest_test/nest_test.py
  62. +0
    -249
      test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs
  63. +0
    -222
      test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs
  64. +0
    -196
      test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs
  65. +0
    -3014
      test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py
  66. +0
    -26
      test/TensorFlowNET.UnitTest/python/train_saver.py

+ 32
- 6
TensorFlow.NET.sln View File

@@ -17,6 +17,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Text", "src\Tens
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Recommenders", "src\TensorFlowNET.Recommenders\Tensorflow.Recommenders.csproj", "{F17AAECB-960A-4E18-A270-BAD776F0E55B}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Native.UnitTest", "test\TensorFlowNET.Native.UnitTest\Tensorflow.Native.UnitTest.csproj", "{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -107,8 +109,8 @@ Global
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.Build.0 = Release|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.Build.0 = Debug|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.Build.0 = Debug|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|x64
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.Build.0 = Debug|x64
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.ActiveCfg = Debug|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.Build.0 = Debug|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
@@ -155,8 +157,8 @@ Global
{49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Release|x86.Build.0 = Release|Any CPU
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.ActiveCfg = Debug|Any CPU
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.Build.0 = Debug|Any CPU
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.ActiveCfg = Debug|x64
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.Build.0 = Debug|x64
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.ActiveCfg = Debug|Any CPU
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.Build.0 = Debug|Any CPU
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
@@ -179,8 +181,8 @@ Global
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x86.Build.0 = Release|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.ActiveCfg = Debug|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.Build.0 = Debug|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.ActiveCfg = Debug|x64
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.Build.0 = Debug|x64
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.ActiveCfg = Debug|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.Build.0 = Debug|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
@@ -201,6 +203,30 @@ Global
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x64.Build.0 = Release|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.ActiveCfg = Release|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.Build.0 = Release|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.Build.0 = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x64.ActiveCfg = Debug|x64
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x64.Build.0 = Debug|x64
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.ActiveCfg = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.Build.0 = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x64.Build.0 = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x86.ActiveCfg = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x86.Build.0 = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|Any CPU.Build.0 = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x64.ActiveCfg = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x64.Build.0 = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x86.ActiveCfg = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x86.Build.0 = Debug|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.ActiveCfg = Release|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.Build.0 = Release|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.ActiveCfg = Release|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.Build.0 = Release|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.ActiveCfg = Release|Any CPU
{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE


+ 1
- 0
src/TensorFlowNET.Console/Tensorflow.Console.csproj View File

@@ -5,6 +5,7 @@
<TargetFramework>netcoreapp3.1</TargetFramework>
<RootNamespace>Tensorflow</RootNamespace>
<AssemblyName>Tensorflow</AssemblyName>
<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>

<ItemGroup>


+ 2
- 9
src/TensorFlowNET.Core/APIs/tf.graph.cs View File

@@ -28,17 +28,10 @@ namespace Tensorflow
=> ops.reset_default_graph();

public Graph get_default_graph()
{
return ops.get_default_graph();
}
=> ops.get_default_graph();

/// <summary>
/// Equivalent to <see cref="get_default_graph"/> but does not create a new graph if there is none.
/// </summary>
public Graph peak_default_graph()
{
return ops.default_graph_stack.peak_controller();
}
=> ops.peak_default_graph();

/// <summary>
/// Creates a new graph.


+ 1
- 13
src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs View File

@@ -37,19 +37,7 @@ namespace Tensorflow.Contexts
if (shouldRunInEager)
return eagerAction();
else
{
if (executing_eagerly())
{
graph_mode();
var result = graphAction();
restore_mode();
return result;
}
else
{
return graphAction();
}
}
return graphAction();
}

// [DebuggerStepThrough]


+ 10
- 2
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -80,9 +80,12 @@ namespace Tensorflow.Contexts
/// Checks whether the current thread has eager execution enabled.
/// </summary>
/// <returns></returns>
[DebuggerStepThrough]
// [DebuggerStepThrough]
public bool executing_eagerly()
{
if(context_switches.Count() == 0)
tf.enable_eager_execution();
return context_switches.Current().EagerMode;
}

@@ -103,11 +106,16 @@ namespace Tensorflow.Contexts
public void restore_mode()
{
context_switches.Pop();
tf.get_default_graph();
}

public void reset_context()
{
c_api.TFE_ContextClearCaches(_handle);
ops.reset_uid();
ops.reset_default_graph();
context_switches.Clear();
if (_handle != null)
c_api.TFE_ContextClearCaches(_handle);
}

public void Dispose()


+ 10
- 0
src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs View File

@@ -40,11 +40,21 @@ namespace Tensorflow.Contexts
});
}

public void Clear()
{
stack.Clear();
}

public void Pop()
{
stack.Pop();
}

public int Count()
{
return stack.Count;
}

public ContextSwitch Current()
{
return stack.Peek();


+ 4
- 2
src/TensorFlowNET.Core/Data/MapDataset.cs View File

@@ -15,11 +15,13 @@ namespace Tensorflow
bool preserve_cardinality = false,
bool use_legacy_function = false) : base(input_dataset)
{
using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");
using var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}");
func.Enter();
var input = tf.placeholder(input_dataset.element_spec[0].dtype);
var output = map_func(input);
func.ToGraph(input, output);
func.Exit();

structure = func.OutputStructure;

variant_tensor = ops.map_dataset(input_dataset.variant_tensor,


+ 16
- 5
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -34,7 +34,6 @@ namespace Tensorflow.Functions
public ConcreteFunction(string name)
{
func_graph = new FuncGraph(name);
func_graph.as_default();
}

public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null)
@@ -46,7 +45,7 @@ namespace Tensorflow.Functions

public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
{
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";

// IntPtr func_handle;
using var graph = new FuncGraph(func_name);
@@ -59,11 +58,12 @@ namespace Tensorflow.Functions
new[] { input },
new[] { output },
null);
graph.Exit();
}

public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
{
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";

// IntPtr func_handle;
using var graph = new FuncGraph(func_name);
@@ -79,12 +79,13 @@ namespace Tensorflow.Functions
new[] { input },
new[] { output.variant_tensor },
null);
graph.Exit();
}

public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
TF_DataType[] dtypes, TensorShape[] shapes)
{
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";

// IntPtr func_handle;
using var graph = new FuncGraph(func_name);
@@ -103,6 +104,7 @@ namespace Tensorflow.Functions
new[] { input1, input2, input3 },
new[] { outputs.Item1, outputs.Item2 },
null);
graph.Exit();
}

public void ToGraph(Tensors inputs, Tensors outputs)
@@ -112,10 +114,19 @@ namespace Tensorflow.Functions
inputs,
outputs,
null);

OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray();
}

public void Enter()
{
func_graph.as_default();
}

public void Exit()
{
func_graph.Exit();
}

public Tensors Invoke(Tensors inputs)
{
var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly());


+ 0
- 1
src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs View File

@@ -26,7 +26,6 @@ namespace Tensorflow.Functions
var output_names = new string[0];

_func_graph = new FuncGraph(graph, name, attrs);
_func_graph.as_default();
_func_graph.ToGraph(operations, inputs, outputs, output_names);
}



+ 2
- 1
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -84,7 +84,7 @@ namespace Tensorflow.Functions
}

var gradients_wrt_outputs = new List<Tensor>();
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}");
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}");
backwards_graph.as_default();
foreach (var output in trainable_outputs)
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));
@@ -101,6 +101,7 @@ namespace Tensorflow.Functions
if (!_func_graph.Outputs.Contains(capture))
_func_graph.Outputs.Add(capture);
}
backwards_graph.Exit();

var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}";
var backward_function_attr = new Dictionary<string, string>();


+ 4
- 2
src/TensorFlowNET.Core/Graphs/AutoGraph.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow.Graphs
{
public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func)
{
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";

// IntPtr func_handle;
using (var graph = new FuncGraph(func_name))
@@ -22,6 +22,7 @@ namespace Tensorflow.Graphs
new[] { input },
new[] { output },
null);
graph.Exit();
}

@@ -39,7 +40,7 @@ namespace Tensorflow.Graphs

public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func)
{
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";

// IntPtr func_handle;
using (var graph = new FuncGraph(func_name))
@@ -54,6 +55,7 @@ namespace Tensorflow.Graphs
new[] { input1, input2 },
new[] { output },
null);
graph.Exit();
}
return (Tensor a, Tensor b) =>


+ 4
- 1
src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow.Graphs

public override void OnEntry(MethodExecutionArgs args)
{
func_name = $"autograph_{args.Instance.GetHashCode()}.{args.Method.Name}";
func_name = $"{args.Method.Name}_{Guid.NewGuid()}";

if (functions.ContainsKey(func_name))
{
@@ -34,6 +34,7 @@ namespace Tensorflow.Graphs
// make function as an Operation by autograph
// need to restore mode when exits
function = new ConcreteFunction(func_name);
function.Enter();

// convert to Tensors
if (args.Arguments[0] is Tensors inputs)
@@ -68,6 +69,8 @@ namespace Tensorflow.Graphs
}
else
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor);
function.Exit();

// cache function.
function.ReturnType = args.ReturnValue.GetType();


+ 18
- 38
src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs View File

@@ -25,63 +25,43 @@ namespace Tensorflow
/// </summary>
public class DefaultGraphStack
{
private readonly List<StackModel> _stack = new List<StackModel>();
private readonly Stack<Graph> _stack = new Stack<Graph>();
Graph _global_default_graph;

public void set_controller(Graph @default)
public Graph get_default()
{
if (!_stack.Exists(x => x.Graph == @default))
_stack.Add(new StackModel { Graph = @default, IsDefault = true });
if (_stack.Count > 0)
return _stack.Peek();
else if (_global_default_graph != null)
return _global_default_graph;
else
_global_default_graph = new Graph();

foreach (var s in _stack)
s.IsDefault = s.Graph == @default;
return _global_default_graph;
}

public Graph get_controller()
public Graph get_controller(Graph g)
{
if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0)
_stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true });
for (var i = _stack.Count - 1; i >= 0; i--)
{
var x = _stack[i];
if (x.IsDefault)
return x.Graph;
}

throw new TensorflowException("Unable to find a default graph");
_stack.Push(g);
return g;
}

public Graph peak_controller()
{
if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0)
if (_stack.Count == 0)
return null;
for (var i = _stack.Count - 1; i >= 0; i--)
{
var x = _stack[i];
if (x.IsDefault)
return x.Graph;
}

return null;
return _stack.Peek();
}

public bool remove(Graph g)
public void pop()
{
if (_stack.Count == 0)
return false;

var sm = _stack.Find(model => model.Graph == g);
return sm != null && _stack.Remove(sm);
_stack.Pop();
}

public void reset()
{
_stack.Clear();
}

private class StackModel
{
public Graph Graph { get; set; }
public bool IsDefault { get; set; }
_global_default_graph = null;
}
}
}

+ 3
- 4
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -94,8 +94,6 @@ namespace Tensorflow.Graphs
// mark_as_return
Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray();

tf.Context.restore_mode();

return func_handle;
}

@@ -247,9 +245,10 @@ namespace Tensorflow.Graphs
return this;
}

protected override void DisposeManagedResources()
public override void Exit()
{
base.DisposeManagedResources();
tf.Context.restore_mode();
ops.pop_graph();
}
}
}

+ 7
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -146,6 +146,7 @@ namespace Tensorflow

/// <summary>
/// Returns a context manager that makes this `Graph` the default graph.
/// Must call Exit() to pop graph
/// </summary>
/// <returns></returns>
public virtual Graph as_default()
@@ -487,7 +488,7 @@ namespace Tensorflow

protected override void DisposeManagedResources()
{
ops.default_graph_stack.remove(this);
}

protected override void DisposeUnmanagedResources(IntPtr handle)
@@ -529,6 +530,11 @@ namespace Tensorflow
return new TensorShape(dims.Select(x => (int)x).ToArray());
}

public virtual void Exit()
{
ops.pop_graph();
}

string debugString = string.Empty;
public override string ToString()
{


+ 1
- 1
src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs View File

@@ -95,7 +95,7 @@ namespace Tensorflow.Graphs
_copy_non_source(op, graph, op_map, base_graph);
}

tf.Context.restore_mode();
graph.Exit();

return op_map;
}


+ 5
- 2
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow

public Operation _apply_op_helper(string op_type_name, string name = null, Dictionary<string, object> keywords = null)
{
var g = ops.get_default_graph();
var g = ops._get_graph_from_inputs(keywords == null ? new object[0] : keywords.Values.ToArray());
var op_def = g.GetOpDef(op_type_name);

// Default name if not specified.
@@ -59,7 +59,8 @@ namespace Tensorflow
var input_types = new List<TF_DataType>();
object values = null;

return tf_with(ops.name_scope(name), scope =>
g.as_default();
var ret_op = tf_with(ops.name_scope(name), scope =>
{
var inferred_from = new Dictionary<string, object>();
var base_types = new List<TF_DataType>();
@@ -249,6 +250,8 @@ namespace Tensorflow

return op;
});
g.Exit();
return ret_op;
}

private void _MaybeColocateWith(ITensorOrOperation[] inputs)


+ 20
- 0
src/TensorFlowNET.Core/ops.cs View File

@@ -78,6 +78,21 @@ namespace Tensorflow
return get_default_graph().get_collection_ref<T>(key);
}

public static Graph _get_graph_from_inputs(params object[] op_input_list)
{
var current_default_graph = get_default_graph();
if (current_default_graph.building_function)
return current_default_graph;

Graph graph = null;
foreach (var op_input in op_input_list)
{
if (op_input is Tensor op_input_tensor)
graph = graph ?? op_input_tensor.graph;
}
return graph ?? current_default_graph;
}

public static Graph _get_graph_from_inputs(Tensors op_input_list)
=> _get_graph_from_inputs(op_input_list: op_input_list, graph: null);

@@ -337,6 +352,11 @@ namespace Tensorflow
return Interlocked.Increment(ref uid_number);
}

public static void reset_uid()
{
uid_number = -1;
}

public static void colocate_with(bool ignore_existing = false)
{
_colocate_with_for_gradient(null, null, ignore_existing);


+ 9
- 9
src/TensorFlowNET.Core/ops.threading.cs View File

@@ -118,16 +118,10 @@ namespace Tensorflow
/// </summary>
/// <returns></returns>
public static Graph get_default_graph()
{
//return _default_graph_stack.get_default()
return default_graph_stack.get_controller();
}
=> default_graph_stack.get_default();

public static Graph set_default_graph(Graph graph)
{
default_graph_stack.set_controller(graph);
return default_graph_stack.get_controller();
}
public static Graph set_default_graph(Graph g)
=> default_graph_stack.get_controller(g);

/// <summary>
/// Clears the default graph stack and resets the global default graph.
@@ -147,5 +141,11 @@ namespace Tensorflow
// "exit the nesting and create a new graph.");
default_graph_stack.reset();
}

public static Graph peak_default_graph()
=> default_graph_stack.peak_controller();

public static void pop_graph()
=> default_graph_stack.pop();
}
}

+ 3
- 3
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -115,10 +115,10 @@ namespace Tensorflow.Keras
public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>();
public void clear_session()
{
ops.reset_default_graph();
tf.Context.reset_context();
reset_uids();
ops.set_default_session(tf.Session(ops.get_default_graph()));
var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase");
// var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase");
_GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>();
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0;
}
@@ -185,7 +185,7 @@ namespace Tensorflow.Keras
return tensor_util.constant_value(outputs);
var source_graph = outputs.graph;
using var exec_graph = _scratch_graph();
var exec_graph = _scratch_graph();
var global_graph = get_graph();
if (source_graph == global_graph && exec_graph != global_graph)
{


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs View File

@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine
_set_mask_metadata(inputs, outputs, null);
});

tf.Context.restore_mode();
graph.Exit();

return outputs;
}


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs View File

@@ -81,8 +81,9 @@ namespace Tensorflow.Keras.Layers
sparse: args.Sparse,
ragged: args.Ragged);

graph.Exit();

isPlaceholder = true;
tf.Context.restore_mode();
}

// Create an input node to add to self.outbound_node


+ 1
- 0
src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj View File

@@ -5,6 +5,7 @@
<Version>0.0.1</Version>
<Description>TensorFlow Recommenders is a library for building recommender system models using TensorFlow.</Description>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>

<ItemGroup>


+ 5
- 0
src/TensorFlowNET.Text/Tensorflow.Text.csproj View File

@@ -7,12 +7,17 @@
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
<Version>0.0.1</Version>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<DefineConstants>DEBUG;TRACE</DefineConstants>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<DefineConstants>DEBUG;TRACE</DefineConstants>
</PropertyGroup>

<ItemGroup>
<None Include="..\..\LICENSE">
<Pack>True</Pack>


test/TensorFlowNET.UnitTest/NativeAPI/CApiAttributesTestcs.cs → test/TensorFlowNET.Native.UnitTest/CApiAttributesTestcs.cs View File

@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest
{
/// <summary>
/// tensorflow\c\c_api_test.cc

test/TensorFlowNET.UnitTest/NativeAPI/CApiColocationTest.cs → test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs View File

@@ -1,9 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Runtime.InteropServices;
using Tensorflow;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest
{
/// <summary>
/// tensorflow\c\c_api_test.cc

test/TensorFlowNET.UnitTest/NativeAPI/CApiFunctionTest.cs → test/TensorFlowNET.Native.UnitTest/CApiFunctionTest.cs View File

@@ -2,10 +2,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using static TensorFlowNET.UnitTest.c_test_util;
using static Tensorflow.Native.UnitTest.c_test_util;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest
{
/// <summary>
/// tensorflow\c\c_api_function_test.cc

test/TensorFlowNET.UnitTest/NativeAPI/CApiGradientsTest.cs → test/TensorFlowNET.Native.UnitTest/CApiGradientsTest.cs View File

@@ -1,11 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using System;
using Tensorflow;
using Tensorflow.Util;
using Buffer = Tensorflow.Buffer;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest
{
/// <summary>
/// tensorflow\c\c_api_test.cc

test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs → test/TensorFlowNET.Native.UnitTest/CApiTest.cs View File

@@ -1,13 +1,11 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Device;
using Tensorflow.Eager;
using Tensorflow.UnitTest;

namespace TensorFlowNET.UnitTest
namespace Tensorflow.Native.UnitTest
{
public class CApiTest : GraphModeTestBase
public class CApiTest
{
protected static readonly TF_Code TF_OK = TF_Code.TF_OK;
protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT;

test/TensorFlowNET.UnitTest/NativeAPI/CSession.cs → test/TensorFlowNET.Native.UnitTest/CSession.cs View File

@@ -1,10 +1,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using Tensorflow.Util;

namespace TensorFlowNET.UnitTest
namespace Tensorflow.Native.UnitTest
{
/// <summary>
/// tensorflow\c\c_test_util.cc

test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs → test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Context.cs View File

@@ -1,9 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.Device;
using Tensorflow.Eager;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{

test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs → test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Execute_MatMul_CPU.cs View File

@@ -1,10 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{

test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs → test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs View File

@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.Eager;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{

test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs → test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs View File

@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.Eager;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{

test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandle.cs → test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandle.cs View File

@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{

test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs → test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs View File

@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.Eager;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{

test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs → test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Variables.cs View File

@@ -1,9 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{

test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs → test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.cs View File

@@ -1,10 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest.Eager
{
/// <summary>
/// tensorflow\c\eager\c_api_test.cc

test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs → test/TensorFlowNET.Native.UnitTest/Eager/GradientEagerTest.cs View File

@@ -1,13 +1,12 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Gradient
namespace Tensorflow.Native.UnitTest.Eager
{
[TestClass]
public class GradientEagerTest : PythonTest
public class GradientEagerTest
{
[TestMethod]
public void ConstantSquare()

test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs → test/TensorFlowNET.Native.UnitTest/GraphBuildTest.cs View File

@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest
{
[TestClass]
public class GraphBuildTest : CApiTest

test/TensorFlowNET.UnitTest/NativeAPI/GraphTest.cs → test/TensorFlowNET.Native.UnitTest/GraphTest.cs View File

@@ -1,10 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using static Tensorflow.Binding;
using Buffer = Tensorflow.Buffer;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace Tensorflow.Native.UnitTest
{
[TestClass]
public class GraphTest : CApiTest

+ 74
- 0
test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs View File

@@ -0,0 +1,74 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;

namespace Tensorflow.Native.UnitTest.Sessions
{
// NOTE(review): the whole fixture is disabled via [Ignore] — presumably until the
// native session bindings are stable in CI; confirm before re-enabling.
[TestClass, Ignore]
public class SessionTest : CApiTest
{
/// <summary>
/// tensorflow\c\c_api_test.cc
/// `TEST(CAPI, Session)`
/// Builds the graph `add = feed + 2`, runs it through a raw C-API session with
/// feed = 3, then extends the same graph with `neg = -add` and re-runs with
/// feed = 7, checking the scalar int32 results (5 and -9) at each step.
/// </summary>
[TestMethod]
public void Session()
{
using var s = new Status();
using var graph = new Graph();

// Make a placeholder operation.
var feed = c_test_util.Placeholder(graph, s);

// Make a constant operation with the scalar "2".
var two = c_test_util.ScalarConst(2, graph, s);

// Add operation.
var add = c_test_util.Add(feed, two, graph, s);

var csession = new CSession(graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

// Run the graph.
var inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(3));
csession.SetInputs(inputs);

var outputs = new TF_Output[] { new TF_Output(add, 0) };
csession.SetOutputs(outputs);

csession.Run(s);
Tensor outTensor = csession.output_tensor(0);
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims);
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); // int32 scalar: 4 bytes (sizeof(uint) == sizeof(int))
var output_contents = outTensor.ToArray<int>();
EXPECT_EQ(3 + 2, output_contents[0]);

// Add another operation to the graph.
var neg = c_test_util.Neg(add, graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

// Run up to the new operation.
// SetInputs/SetOutputs are re-issued because the previous run consumed them.
inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(7));
csession.SetInputs(inputs);
outputs = new TF_Output[] { new TF_Output(neg, 0) };
csession.SetOutputs(outputs);
csession.Run(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

outTensor = csession.output_tensor(0);
ASSERT_TRUE(outTensor != IntPtr.Zero);
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims); // scalar
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
output_contents = outTensor.ToArray<int>();
EXPECT_EQ(-(7 + 2), output_contents[0]);

// Clean up
csession.CloseAndDelete(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);
}
}
}

+ 36
- 0
test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj View File

@@ -0,0 +1,36 @@
<Project Sdk="Microsoft.NET.Sdk">

  <!-- Test project for the raw TensorFlow C-API bindings (MSTest). -->
  <PropertyGroup>
    <TargetFramework>netcoreapp3.1</TargetFramework>

    <IsPackable>false</IsPackable>

    <Platforms>AnyCPU;x64</Platforms>
  </PropertyGroup>

  <!-- Native tensorflow.dll is 64-bit only, so both Debug configurations force x64
       and allow unsafe blocks for pointer-based tensor tests. -->
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
    <DefineConstants>DEBUG;TRACE</DefineConstants>
    <PlatformTarget>x64</PlatformTarget>
  </PropertyGroup>

  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
    <DefineConstants>DEBUG;TRACE</DefineConstants>
    <PlatformTarget>x64</PlatformTarget>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="FluentAssertions" Version="5.10.3" />
    <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.8.3" />
    <PackageReference Include="MSTest.TestAdapter" Version="2.1.2" />
    <PackageReference Include="MSTest.TestFramework" Version="2.1.2" />
    <PackageReference Include="coverlet.collector" Version="1.3.0" />
    <!-- Redist supplies the native tensorflow library the C-API tests load. -->
    <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" />
  </ItemGroup>

  <ItemGroup>
    <ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
  </ItemGroup>

</Project>

+ 204
- 0
test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs View File

@@ -0,0 +1,204 @@
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using System;
using System.Linq;
using System.Runtime.InteropServices;
using static Tensorflow.Binding;

namespace Tensorflow.Native.UnitTest.Tensors
{
/// <summary>
/// Ports of the tensor-related cases from tensorflow\c\c_api_test.cc,
/// exercised directly against the C-API bindings.
/// </summary>
[TestClass]
public class TensorTest : CApiTest
{
// Wraps an unmanaged pointer in a Tensor without copying; the Tensor must
// report the byte size passed to the constructor, not the backing array's size.
[TestMethod]
public unsafe void TensorFromFixed()
{
var array = new float[1000];
var span = new Span<float>(array, 100, 500);
fixed (float* ptr = &MemoryMarshal.GetReference(span))
{
using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, 4 * span.Length))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(2000, (int)t.bytesize); // 500 floats * 4 bytes
}
}

fixed (float* ptr = &array[0])
{
using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(4000, (int)t.bytesize); // 1000 floats * 4 bytes
}
}
}

// Managed-array constructor: byte size follows the array; a null shape
// produces a scalar-like tensor with an empty shape.
[TestMethod]
public void TensorFromArray()
{
var array = new float[1000];
using (var t = new Tensor(array, new long[] { array.Length }, tf.float32))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(1000 * sizeof(float), (int)t.bytesize);
}

using (var t = new Tensor(new float[] { 1 }, new long[] { 1 }, tf.float32))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(1 * sizeof(float), (int)t.bytesize);
}

using (var t = new Tensor(new float[] { 1 }, null, tf.float32))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(1 * sizeof(float), (int)t.bytesize);
t.shape.Should().BeEmpty();
}
}

// TF_AllocateTensor hands back TF-owned memory; dtype/rank/shape/bytesize
// must reflect the requested 2x3 float buffer.
[TestMethod]
public void AllocateTensor()
{
ulong num_bytes = 6 * sizeof(float);
long[] dims = { 2, 3 };
Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes);
EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
EXPECT_EQ(2, t.NDims);
EXPECT_EQ((int)dims[0], t.shape[0]);
EXPECT_EQ(num_bytes, t.bytesize);
t.Dispose();
}


/// <summary>
/// Port from c_api_test.cc
/// `TEST(CAPI, MaybeMove)`
/// </summary>
[TestMethod, Ignore]
public void MaybeMove()
{
NDArray nd = np.array(2, 3);
Tensor t = new Tensor(nd);
Tensor o = t.MaybeMove();
ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own.
t.Dispose();
}

/// <summary>
/// Port from c_api_test.cc
/// `TEST(CAPI, Tensor)`
/// Round-trips a 2x3 float NDArray through a Tensor and checks metadata.
/// </summary>
[TestMethod]
public void Tensor()
{
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);

var tensor = new Tensor(nd);
var array = tensor.ToArray<float>();

EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
EXPECT_EQ(tensor.rank, nd.ndim);
EXPECT_EQ((int)tensor.shape[0], nd.shape[0]);
EXPECT_EQ((int)tensor.shape[1], nd.shape[1]);
EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float));
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] { 1, 2, 3, 4, 5, 6 }));
}

/// <summary>
/// Port from tensorflow\c\c_api_test.cc
/// `TEST(CAPI, SetShape)`
/// Exercises TF_GraphSetTensorShape / TF_GraphGetTensorShape transitions:
/// unknown rank -> partially known -> fully known, plus invalid updates.
/// </summary>
[TestMethod]
public void SetShape()
{
var s = new Status();
var graph = new Graph().as_default();

var feed = c_test_util.Placeholder(graph, s);
var feed_out_0 = new TF_Output(feed, 0);

// Fetch the shape, it should be completely unknown.
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);

Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(-1, num_dims);

// Set the shape to be unknown, expect no change.
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
EXPECT_EQ(-1, num_dims);

// Set the shape to be 2 x Unknown
long[] dims = { 2, -1 };
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
EXPECT_EQ(2, num_dims);

// Get the dimension vector appropriately.
var returned_dims = new long[dims.Length];
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

// Set to a new valid shape: [2, 3]
dims[1] = 3;
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);

// Fetch and see that the new value is returned.
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

// Try to set 'unknown' with unknown rank on the shape and see that
// it doesn't change.
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, (int)returned_dims[0]);
EXPECT_EQ(3, (int)returned_dims[1]);

// Try to set 'unknown' with same rank on the shape and see that
// it doesn't change.
dims[0] = -1;
dims[1] = -1;
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, (int)returned_dims[0]);
EXPECT_EQ(3, (int)returned_dims[1]);

// Try to fetch a shape with the wrong num_dims
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);

// Try to set an invalid shape (cannot change 2x3 to a 2x5).
dims[1] = 5;
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);

// Test for a scalar.
var three = c_test_util.ScalarConst(3, graph, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
var three_out_0 = new TF_Output(three, 0);

num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(0, num_dims);
// NOTE(review): c_api_test.cc queries three_out_0 here and expects TF_OK;
// this port queries feed_out_0 (rank 2) with num_dims == 0, which yields
// TF_INVALID_ARGUMENT — confirm the divergence is intentional.
c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);

graph.Exit();
s.Dispose();
}
}
}

test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs → test/TensorFlowNET.Native.UnitTest/c_test_util.cs View File

@@ -1,10 +1,8 @@
using System;
using System.Diagnostics.CodeAnalysis;
using Tensorflow;
using Tensorflow.Util;
using Buffer = Tensorflow.Buffer;

namespace TensorFlowNET.UnitTest
namespace Tensorflow.Native.UnitTest
{
/// <summary>
/// Port from `tensorflow\c\c_test_util.cc`

+ 3
- 70
test/TensorFlowNET.UnitTest/Basics/SessionTest.cs View File

@@ -8,78 +8,11 @@ using Tensorflow;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace TensorFlowNET.UnitTest
{
[TestClass]
public class SessionTest : CApiTest
[TestClass, Ignore]
public class SessionTest
{
/// <summary>
/// tensorflow\c\c_api_test.cc
/// `TEST(CAPI, Session)`
/// </summary>
[TestMethod, Ignore]
public void Session()
{
lock (Locks.ProcessWide)
{
var s = new Status();
var graph = new Graph().as_default();

// Make a placeholder operation.
var feed = c_test_util.Placeholder(graph, s);

// Make a constant operation with the scalar "2".
var two = c_test_util.ScalarConst(2, graph, s);

// Add operation.
var add = c_test_util.Add(feed, two, graph, s);

var csession = new CSession(graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

// Run the graph.
var inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(3));
csession.SetInputs(inputs);

var outputs = new TF_Output[] { new TF_Output(add, 0) };
csession.SetOutputs(outputs);

csession.Run(s);
Tensor outTensor = csession.output_tensor(0);
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims);
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
var output_contents = outTensor.ToArray<int>();
EXPECT_EQ(3 + 2, output_contents[0]);

// Add another operation to the graph.
var neg = c_test_util.Neg(add, graph, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

// Run up to the new operation.
inputs = new Dictionary<Operation, Tensor>();
inputs.Add(feed, new Tensor(7));
csession.SetInputs(inputs);
outputs = new TF_Output[] { new TF_Output(neg, 0) };
csession.SetOutputs(outputs);
csession.Run(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);

outTensor = csession.output_tensor(0);
ASSERT_TRUE(outTensor != IntPtr.Zero);
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims); // scalar
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
output_contents = outTensor.ToArray<int>();
EXPECT_EQ(-(7 + 2), output_contents[0]);

// Clean up
csession.CloseAndDelete(s);
ASSERT_EQ(TF_Code.TF_OK, s.Code);
}
}

[TestMethod]
public void EvalTensor()
{


+ 3
- 220
test/TensorFlowNET.UnitTest/Basics/TensorTest.cs View File

@@ -7,201 +7,11 @@ using System.Runtime.InteropServices;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace TensorFlowNET.UnitTest
{
[TestClass]
public class TensorTest : CApiTest
[TestClass, Ignore]
public class TensorTest
{
[TestMethod]
public unsafe void TensorFromFixed()
{
var array = new float[1000];
var span = new Span<float>(array, 100, 500);
fixed (float* ptr = &MemoryMarshal.GetReference(span))
{
using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, 4 * span.Length))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(2000, (int)t.bytesize);
}
}

fixed (float* ptr = &array[0])
{
using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(4000, (int)t.bytesize);
}
}
}

[TestMethod]
public unsafe void TensorFromArray()
{
var array = new float[1000];
using (var t = new Tensor(array, new long[] { array.Length }, tf.float32))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(1000 * sizeof(float), (int)t.bytesize);
}

using (var t = new Tensor(new float[] { 1 }, new long[] { 1 }, tf.float32))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(1 * sizeof(float), (int)t.bytesize);
}

using (var t = new Tensor(new float[] { 1 }, null, tf.float32))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(1 * sizeof(float), (int)t.bytesize);
t.shape.Should().BeEmpty();
}
}

[TestMethod]
public void AllocateTensor()
{
ulong num_bytes = 6 * sizeof(float);
long[] dims = { 2, 3 };
Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes);
EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
EXPECT_EQ(2, t.NDims);
EXPECT_EQ((int)dims[0], t.shape[0]);
EXPECT_EQ(num_bytes, t.bytesize);
t.Dispose();
}


/// <summary>
/// Port from c_api_test.cc
/// `TEST(CAPI, MaybeMove)`
/// </summary>
[TestMethod, Ignore]
public void MaybeMove()
{
NDArray nd = np.array(2, 3);
Tensor t = new Tensor(nd);
Tensor o = t.MaybeMove();
ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own.
t.Dispose();
}

/// <summary>
/// Port from c_api_test.cc
/// `TEST(CAPI, Tensor)`
/// </summary>
[TestMethod]
public void Tensor()
{
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);

var tensor = new Tensor(nd);
var array = tensor.ToArray<float>();

EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
EXPECT_EQ(tensor.rank, nd.ndim);
EXPECT_EQ((int)tensor.shape[0], nd.shape[0]);
EXPECT_EQ((int)tensor.shape[1], nd.shape[1]);
EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float));
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] { 1, 2, 3, 4, 5, 6 }));
}

/// <summary>
/// Port from tensorflow\c\c_api_test.cc
/// `TEST(CAPI, SetShape)`
/// </summary>
[TestMethod]
public void SetShape()
{
var s = new Status();
var graph = new Graph().as_default();

var feed = c_test_util.Placeholder(graph, s);
var feed_out_0 = new TF_Output(feed, 0);

// Fetch the shape, it should be completely unknown.
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);

Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(-1, num_dims);

// Set the shape to be unknown, expect no change.
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
EXPECT_EQ(-1, num_dims);

// Set the shape to be 2 x Unknown
long[] dims = { 2, -1 };
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
EXPECT_EQ(2, num_dims);

// Get the dimension vector appropriately.
var returned_dims = new long[dims.Length];
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

// Set to a new valid shape: [2, 3]
dims[1] = 3;
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);

// Fetch and see that the new value is returned.
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

// Try to set 'unknown' with unknown rank on the shape and see that
// it doesn't change.
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, (int)returned_dims[0]);
EXPECT_EQ(3, (int)returned_dims[1]);

// Try to set 'unknown' with same rank on the shape and see that
// it doesn't change.
dims[0] = -1;
dims[1] = -1;
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, (int)returned_dims[0]);
EXPECT_EQ(3, (int)returned_dims[1]);

// Try to fetch a shape with the wrong num_dims
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);

// Try to set an invalid shape (cannot change 2x3 to a 2x5).
dims[1] = 5;
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);

// Test for a scalar.
var three = c_test_util.ScalarConst(3, graph, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
var three_out_0 = new TF_Output(three, 0);

num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle);
Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(0, num_dims);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s.Handle);
//Assert.IsTrue(s.Code == TF_Code.TF_OK);

// graph.Dispose();
s.Dispose();
}

[TestMethod]
public void sparse_to_dense()
{
@@ -271,32 +81,5 @@ namespace TensorFlowNET.UnitTest.NativeAPI
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>()));
}
}

/// <summary>
/// Creates a tensor from an image of 256x256x3 and resizes it to 100x100x3
/// </summary>
[TestMethod]
public unsafe void tensor_resize()
{
tf.enable_eager_execution();

var imageArray = new float[256 * 256 * 3];

using var newSize = tf.convert_to_tensor(new int[] { 100, 100 });

using (var t = tf.constant(imageArray, tf.float32, (1, 256, 256, 3)))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(256 * 256 * 3 * sizeof(float), (int)t.bytesize);

using var resized = tf.image.resize_bilinear(t, newSize);
EXPECT_EQ(resized.shape[0], 1);
EXPECT_EQ(resized.shape[1], 100);
EXPECT_EQ(resized.shape[2], 100);
EXPECT_EQ(resized.shape[3], 3);
}

tf.compat.v1.disable_eager_execution();
}
}
}

+ 30
- 0
test/TensorFlowNET.UnitTest/EagerModeTestBase.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
@@ -10,11 +11,40 @@ namespace TensorFlowNET.UnitTest
{
if (!tf.executing_eagerly())
tf.enable_eager_execution();
tf.Context.ensure_initialized();
}

[TestCleanup]
public void TestClean()
{
}

/// <summary>
/// Element-wise approximate equality for float arrays, with an absolute
/// tolerance of 1e-6 per element.
/// </summary>
/// <param name="f1">First array (never indexed past its own length).</param>
/// <param name="f2">Second array.</param>
/// <returns>True when both arrays have the same length and every pair of
/// elements differs by at most the tolerance; two empty arrays are equal.</returns>
public bool Equal(float[] f1, float[] f2)
{
    // Fixes two defects in the original: (1) it indexed f2 by f1's length,
    // throwing IndexOutOfRangeException when f2 was shorter and silently
    // comparing only a prefix when f2 was longer; (2) empty arrays compared
    // as unequal because the accumulator started as false.
    if (f1.Length != f2.Length)
        return false;

    var tolerance = .000001f;
    for (var i = 0; i < f1.Length; i++)
    {
        if (Math.Abs(f1[i] - f2[i]) > tolerance)
            return false;
    }

    return true;
}

/// <summary>
/// Element-wise approximate equality for double arrays, with an absolute
/// tolerance of 1e-15 per element.
/// </summary>
/// <param name="d1">First array (never indexed past its own length).</param>
/// <param name="d2">Second array.</param>
/// <returns>True when both arrays have the same length and every pair of
/// elements differs by at most the tolerance; two empty arrays are equal.</returns>
public bool Equal(double[] d1, double[] d2)
{
    // Fixes two defects in the original: (1) it indexed d2 by d1's length,
    // throwing IndexOutOfRangeException when d2 was shorter and silently
    // comparing only a prefix when d2 was longer; (2) empty arrays compared
    // as unequal because the accumulator started as false.
    if (d1.Length != d2.Length)
        return false;

    var tolerance = .000000000000001;
    for (var i = 0; i < d1.Length; i++)
    {
        if (Math.Abs(d1[i] - d2[i]) > tolerance)
            return false;
    }

    return true;
}
}
}

+ 2
- 2
test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs View File

@@ -6,10 +6,10 @@ using System.Threading;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.NativeAPI
namespace TensorFlowNET.UnitTest
{
[TestClass]
public class EnforcedSinglethreadingTests : CApiTest
public class EnforcedSinglethreadingTests
{
private static readonly object _singlethreadLocker = new object();



+ 2
- 0
test/TensorFlowNET.UnitTest/GraphModeTestBase.cs View File

@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using TensorFlowNET.UnitTest;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.UnitTest
{
@@ -15,6 +16,7 @@ namespace Tensorflow.UnitTest
[TestCleanup]
public void TestClean()
{
keras.backend.clear_session();
tf.enable_eager_execution();
}
}


+ 1
- 1
test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs View File

@@ -5,7 +5,7 @@ using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.nn_test
{
[TestClass]
public class ActivationFunctionTest : TFNetApiTest
public class ActivationFunctionTest : EagerModeTestBase
{
// A constant vector of size 6
Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f });


+ 1
- 1
test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs View File

@@ -6,7 +6,7 @@ using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.ManagedAPI
{
[TestClass]
public class BitwiseApiTest : TFNetApiTest
public class BitwiseApiTest : EagerModeTestBase
{
[TestInitialize]
public void Init()


+ 1
- 1
test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs View File

@@ -7,7 +7,7 @@ using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.ManagedAPI
{
[TestClass]
public class FunctionApiTest : TFNetApiTest
public class FunctionApiTest : EagerModeTestBase
{
Tensor Min(Tensor a, Tensor b)
{


+ 1
- 1
test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs View File

@@ -6,7 +6,7 @@ using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.ManagedAPI
{
[TestClass]
public class MathApiTest : TFNetApiTest
public class MathApiTest : EagerModeTestBase
{
// A constant vector of size 6
Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f });


+ 0
- 35
test/TensorFlowNET.UnitTest/ManagedAPI/TFNetApiTest.cs View File

@@ -1,35 +0,0 @@
using System;

namespace TensorFlowNET.UnitTest
{
// Shared approximate-equality helpers for float/double arrays.
// CAUTION: both overloads iterate over the FIRST array's length only, so the
// second array must be at least as long (shorter throws
// IndexOutOfRangeException, longer compares only a prefix), and two empty
// arrays compare as unequal because the accumulator starts as false.
public class TFNetApiTest
{
// Element-wise |f1[i] - f2[i]| <= 1e-6 over f1's indices.
public bool Equal(float[] f1, float[] f2)
{
bool ret = false;
var tolerance = .000001f;
for (var i = 0; i < f1.Length; i++)
{
ret = Math.Abs(f1[i] - f2[i]) <= tolerance;
if (!ret)
break;
}

return ret;
}

// Element-wise |d1[i] - d2[i]| <= 1e-15 over d1's indices.
// NOTE: the tolerance literal carries an 'f' (float) suffix even though the
// inputs are doubles.
public bool Equal(double[] d1, double[] d2)
{
bool ret = false;
var tolerance = .000000000000001f;
for (var i = 0; i < d1.Length; i++)
{
ret = Math.Abs(d1[i] - d2[i]) <= tolerance;
if (!ret)
break;
}

return ret;
}
}
}

+ 0
- 86
test/TensorFlowNET.UnitTest/ManagedAPI/ZeroFractionTest.cs View File

@@ -1,86 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using System;
using System.Linq;
using Tensorflow;
using Tensorflow.UnitTest;

namespace TensorFlowNET.UnitTest.nn_test
{
/// <summary>
/// Port of the zero_fraction tests from tensorflow/python's nn_impl test
/// suite. Every test is [Ignore]d pending an implementation of
/// nn_impl.zero_fraction; several bodies still contain untranslated
/// Python idioms (`self.evaluate`, `assert(...)`) — presumably provided by
/// the PythonTest-style base class, TODO confirm they compile.
/// </summary>
[TestClass]
public class ZeroFractionTest : GraphModeTestBase
{
// Reference implementation: fraction of entries whose magnitude is <= eps.
protected double _ZeroFraction(NDArray x)
{
assert(x.shape);
int total_elements = np.prod(x.shape);

var eps = 1e-8;
var nonzeros = x.Data<double>().Count(d => Math.Abs(d) > eps);
return 1.0 - nonzeros / (double)total_elements;
}

// Compares nn_impl.zero_fraction against the reference on random 0/1 data.
[Ignore("TODO implement nn_impl.zero_fraction")]
[TestMethod]
public void testZeroFraction()
{
var x_shape = new Shape(5, 17);
var x_np = np.random.randint(0, 2, x_shape);
//x_np.astype(np.float32);
var y_np = this._ZeroFraction(x_np);

var x_tf = constant_op.constant(x_np);
x_tf.set_shape(x_shape);
var y_tf = nn_impl.zero_fraction(x_tf);
var y_tf_np = self.evaluate<NDArray>(y_tf);

var eps = 1e-8;
self.assertAllClose(y_tf_np, y_np, eps);
}

// Empty input: the fraction is undefined, so the result must be NaN.
[Ignore("TODO implement nn_impl.zero_fraction")]
[TestMethod]
public void testZeroFractionEmpty()
{

var x = np.zeros(0);
var y = self.evaluate<NDArray>(nn_impl.zero_fraction(new Tensor(x)));
self.assertTrue(np.isnan(y));
}

// Large (> 2^27 element) all-zero input exercises the int64 counting path.
[Ignore("TODO implement nn_impl.zero_fraction")]
[TestMethod]
public void testZeroFraction2_27Zeros()
{
var sparsity = nn_impl.zero_fraction(
array_ops.zeros(new Shape((int)Math.Pow(2, 27 * 1.01)), dtypes.int8));
self.assertAllClose(1.0, self.evaluate<NDArray>(sparsity));
}

// Large all-ones input: zero fraction must be exactly 0.
[Ignore("TODO implement nn_impl.zero_fraction")]
[TestMethod]
public void testZeroFraction2_27Ones()
{
var sparsity = nn_impl.zero_fraction(
array_ops.ones(new TensorShape((int)Math.Pow(2, 27 * 1.01)), dtypes.int8));
self.assertAllClose(0.0, self.evaluate<NDArray>(sparsity));
}

// Placeholder input with unknown static shape; run-time feed still works.
[Ignore("TODO implement nn_impl.zero_fraction")]
[TestMethod]
public void testUnknownSize()
{
var value = array_ops.placeholder(dtype: dtypes.float32);
var sparsity = nn_impl.zero_fraction(value);
using (var sess = self.cached_session())
{
// TODO: make this compile
//self.assertAllClose(
// 0.25,
// sess.run(sparsity, {value: [[0., 1.], [0.3, 2.]]}));
}
}


}
}

+ 1
- 0
test/TensorFlowNET.UnitTest/MultithreadingTests.cs View File

@@ -73,6 +73,7 @@ namespace TensorFlowNET.UnitTest
tf.peak_default_graph().Should().BeNull();
var beforehand = tf.get_default_graph(); //this should create default automatically.
beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread.");
beforehand.as_default();
tf.peak_default_graph().Should().NotBeNull();

using (var sess = tf.Session())


+ 12
- 3
test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs View File

@@ -1,4 +1,5 @@
using System.IO;
using System;
using System.IO;

namespace TensorFlowNET.UnitTest
{
@@ -6,8 +7,16 @@ namespace TensorFlowNET.UnitTest
{
public static string GetFullPathFromDataDir(string fileName)
{
var dir = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "..", "..", "data");
return Path.GetFullPath(Path.Combine(dir, fileName));
var dataDir = GetRootContentDir(Directory.GetCurrentDirectory());
return Path.Combine(dataDir, fileName);
}

/// <summary>
/// Walks upward from <paramref name="dir"/> until a child directory named
/// "data" is found, and returns its full path.
/// </summary>
/// <param name="dir">Directory to start the search from.</param>
/// <returns>Full path of the nearest ancestor's "data" directory.</returns>
/// <exception cref="DirectoryNotFoundException">
/// Thrown when no "data" directory exists anywhere up to the filesystem root.
/// </exception>
static string GetRootContentDir(string dir)
{
    // The original recursed unconditionally; at the filesystem root,
    // Path.Combine(dir, "..") resolves back to the root itself, so a missing
    // "data" directory recursed forever (StackOverflowException). This
    // iterative version terminates with a descriptive exception instead.
    var current = Path.GetFullPath(dir);
    while (true)
    {
        var path = Path.GetFullPath(Path.Combine(current, "data"));
        if (Directory.Exists(path))
            return path;

        var parent = Path.GetFullPath(Path.Combine(current, ".."));
        if (parent == current) // reached the filesystem root
            throw new DirectoryNotFoundException($"No 'data' directory found in any ancestor of '{dir}'.");
        current = parent;
    }
}
}
}

+ 0
- 1059
test/TensorFlowNET.UnitTest/control_flow_ops_test/control_flow_ops_test.py
File diff suppressed because it is too large
View File


+ 0
- 875
test/TensorFlowNET.UnitTest/nest_test/NestTest.cs View File

@@ -1,875 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Newtonsoft.Json.Linq;
using NumSharp;
using System;
using System.Collections;
using System.Collections.Generic;
using Tensorflow;
using Tensorflow.UnitTest;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.nest_test
{
/// <summary>
/// excerpt of tensorflow/python/framework/util/nest_test.py
/// </summary>
[TestClass]
public class NestTest : GraphModeTestBase
{
[TestInitialize]
public void TestInitialize()
{
tf.Graph().as_default();
}

//public class PointXY
//{
// public double x;
// public double y;
//}

// if attr:
// class BadAttr(object):
// """Class that has a non-iterable __attrs_attrs__."""
// __attrs_attrs__ = None

// @attr.s
// class SampleAttr(object):
// field1 = attr.ib()
// field2 = attr.ib()

// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testAttrsFlattenAndPack(self) :
// if attr is None:
// self.skipTest("attr module is unavailable.")

// field_values = [1, 2]
// sample_attr = NestTest.SampleAttr(* field_values)
// self.assertFalse(nest._is_attrs(field_values))
// self.assertTrue(nest._is_attrs(sample_attr))
// flat = nest.flatten(sample_attr)
// self.assertEqual(field_values, flat)
// restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
// self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
// self.assertEqual(restructured_from_flat, sample_attr)

//# Check that flatten fails if attributes are not iterable
// with self.assertRaisesRegexp(TypeError, "object is not iterable"):
// flat = nest.flatten(NestTest.BadAttr())
[Ignore]
[TestMethod]
public void testFlattenAndPack()
{
object structure = new object[] { new object[] { 3, 4 }, 5, new object[] { 6, 7, new object[] { 9, 10 }, 8 } };
var flat = new List<object> { "a", "b", "c", "d", "e", "f", "g", "h" };

self.assertEqual(nest.flatten(structure), new[] { 3, 4, 5, 6, 7, 9, 10, 8 });
self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(),
JArray.FromObject(new object[] { new object[] { "a", "b" }, "c", new object[] { "d", "e", new object[] { "f", "g" }, "h" } }).ToString());
structure = new object[] { new Hashtable { ["x"] = 4, ["y"] = 2 }, new object[] { new object[] { new Hashtable { ["x"] = 1, ["y"] = 0 }, }, } };
flat = new List<object> { 4, 2, 1, 0 };
self.assertEqual(nest.flatten(structure), flat);
var restructured_from_flat = nest.pack_sequence_as(structure, flat) as object[];
//Console.WriteLine(JArray.FromObject(restructured_from_flat));
self.assertEqual(restructured_from_flat, structure);
self.assertEqual((restructured_from_flat[0] as Hashtable)["x"], 4);
self.assertEqual((restructured_from_flat[0] as Hashtable)["y"], 2);
self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["x"], 1);
self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["y"], 0);

self.assertEqual(new List<object> { 5 }, nest.flatten(5));
var flat1 = nest.flatten(np.array(new[] { 5 }));
self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat1);

self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" }));
self.assertEqual(np.array(new[] { 5 }),
nest.pack_sequence_as("scalar", new List<object> { np.array(new[] { 5 }) }));

Assert.ThrowsException<ValueError>(() => nest.pack_sequence_as("scalar", new List<object>() { 4, 5 }));

Assert.ThrowsException<ValueError>(() =>
nest.pack_sequence_as(new object[] { 5, 6, new object[] { 7, 8 } }, new List<object> { "a", "b", "c" }));
}

// @parameterized.parameters({"mapping_type": collections.OrderedDict
// },
// {"mapping_type": _CustomMapping
//})
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testFlattenDictOrder(self, mapping_type) :
// """`flatten` orders dicts by key, including OrderedDicts."""
// ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
// plain = {"d": 3, "b": 1, "a": 0, "c": 2}
// ordered_flat = nest.flatten(ordered)
// plain_flat = nest.flatten(plain)
// self.assertEqual([0, 1, 2, 3], ordered_flat)
// self.assertEqual([0, 1, 2, 3], plain_flat)

// @parameterized.parameters({"mapping_type": collections.OrderedDict},
// {"mapping_type": _CustomMapping})
// def testPackDictOrder(self, mapping_type):
// """Packing orders dicts by key, including OrderedDicts."""
// custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
// plain = {"d": 0, "b": 0, "a": 0, "c": 0}
// seq = [0, 1, 2, 3]
//custom_reconstruction = nest.pack_sequence_as(custom, seq)
//plain_reconstruction = nest.pack_sequence_as(plain, seq)
// self.assertIsInstance(custom_reconstruction, mapping_type)
// self.assertIsInstance(plain_reconstruction, dict)
// self.assertEqual(
// mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
// custom_reconstruction)
// self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)

// Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name

// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testFlattenAndPack_withDicts(self) :
// # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
// mess = [
// "z",
// NestTest.Abc(3, 4), {
// "d": _CustomMapping({
// 41: 4
// }),
// "c": [
// 1,
// collections.OrderedDict([
// ("b", 3),
// ("a", 2),
// ]),
// ],
// "b": 5
// }, 17
// ]

// flattened = nest.flatten(mess)
// self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])

// structure_of_mess = [
// 14,
// NestTest.Abc("a", True),
// {
// "d": _CustomMapping({
// 41: 42
// }),
// "c": [
// 0,
// collections.OrderedDict([
// ("b", 9),
// ("a", 8),
// ]),
// ],
// "b": 3
// },
// "hi everybody",
// ]

// unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
// self.assertEqual(unflattened, mess)

// # Check also that the OrderedDict was created, with the correct key order.
//unflattened_ordered_dict = unflattened[2]["c"][1]
// self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
// self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])

// unflattened_custom_mapping = unflattened[2]["d"]
// self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
// self.assertEqual(list(unflattened_custom_mapping.keys()), [41])

[TestMethod]
public void testFlatten_numpyIsNotFlattened()
{
    // An ndarray is a single leaf: flatten must not expand it into
    // its scalar elements, so the flattened list has exactly one entry.
    var flattened = nest.flatten(np.array(1, 2, 3));
    self.assertEqual(len(flattened), 1);
}

[TestMethod]
public void testFlatten_stringIsNotFlattened()
{
    // A string is a leaf even though it is enumerable: flatten yields a
    // single element, and packing it against any scalar structure
    // round-trips the original string.
    var text = "lots of letters";
    var leaves = nest.flatten(text);
    self.assertEqual(len(leaves), 1);
    var repacked = nest.pack_sequence_as("goodbye", leaves);
    self.assertEqual(text, repacked);
}

// def testPackSequenceAs_notIterableError(self) :
// with self.assertRaisesRegexp(TypeError,
// "flat_sequence must be a sequence"):
// nest.pack_sequence_as("hi", "bye")

[TestMethod]
public void testPackSequenceAs_wrongLengthsError()
{
    // The structure has 2 leaves but the flat sequence carries 3 values;
    // pack_sequence_as must reject the mismatch with a ValueError.
    // (Python reference message: "Structure had 2 elements, but
    // flat_sequence had 3 elements.")
    var structure = new object[] { "hello", "world" };
    var flatSequence = new object[] { "and", "goodbye", "again" };
    Assert.ThrowsException<ValueError>(() => nest.pack_sequence_as(structure, flatSequence));
}

[Ignore]
[TestMethod]
public void testIsSequence()
{
    // Strings are leaves despite being enumerable.
    self.assertFalse(nest.is_sequence("1234"));
    // Object arrays (nested or empty) and hashtables count as sequences.
    self.assertTrue(nest.is_sequence(new object[] { 1, 3, new object[] { 4, 5 } }));
    // TODO: ValueTuple<T,T>
    //self.assertTrue(nest.is_sequence(((7, 8), (5, 6))));
    self.assertTrue(nest.is_sequence(new object[] { }));
    self.assertTrue(nest.is_sequence(new Hashtable { ["a"] = 1, ["b"] = 2 }));
    // Sets, tensors, and ndarrays are all treated as single leaves.
    self.assertFalse(nest.is_sequence(new HashSet<int> { 1, 2 }));
    var tensor = array_ops.ones(new int[] { 2, 3 });
    self.assertFalse(nest.is_sequence(tensor));
    self.assertFalse(nest.is_sequence(gen_math_ops.tanh(tensor)));
    self.assertFalse(nest.is_sequence(np.ones(new int[] { 4, 5 })));
}

// @parameterized.parameters({"mapping_type": _CustomMapping},
// {"mapping_type": dict})
// def testFlattenDictItems(self, mapping_type):
// dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))})
// flat = {4: "a", 5: "b", 6: "c", 8: "d"}
// self.assertEqual(nest.flatten_dict_items(dictionary), flat)

// with self.assertRaises(TypeError):
// nest.flatten_dict_items(4)

// bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))})
// with self.assertRaisesRegexp(ValueError, "not unique"):
// nest.flatten_dict_items(bad_dictionary)

// another_bad_dictionary = mapping_type({
// (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
// })
// with self.assertRaisesRegexp(
// ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
// nest.flatten_dict_items(another_bad_dictionary)

//# pylint does not correctly recognize these as class names and
//# suggests to use variable style under_score naming.
//# pylint: disable=invalid-name
// Named0ab = collections.namedtuple("named_0", ("a", "b"))
// Named1ab = collections.namedtuple("named_1", ("a", "b"))
// SameNameab = collections.namedtuple("same_name", ("a", "b"))
// SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
// SameNamexy = collections.namedtuple("same_name", ("x", "y"))
// SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
// SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
// NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
// # pylint: enable=invalid-name

// class SameNamedType1(SameNameab):
// pass

// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testAssertSameStructure(self):
// structure1 = (((1, 2), 3), 4, (5, 6))
// structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
// structure_different_num_elements = ("spam", "eggs")
// structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
// nest.assert_same_structure(structure1, structure2)
// nest.assert_same_structure("abc", 1.0)
// nest.assert_same_structure("abc", np.array([0, 1]))
// nest.assert_same_structure("abc", constant_op.constant([0, 1]))

// with self.assertRaisesRegexp(
// ValueError,
// ("The two structures don't have the same nested structure\\.\n\n"
// "First structure:.*?\n\n"
// "Second structure:.*\n\n"
// "More specifically: Substructure "
// r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
// 'substructure "type=str str=spam" is not\n'
// "Entire first structure:\n"
// r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
// "Entire second structure:\n"
// r"\(\., \.\)")):
// nest.assert_same_structure(structure1, structure_different_num_elements)

// with self.assertRaisesRegexp(
// ValueError,
// ("The two structures don't have the same nested structure\\.\n\n"
// "First structure:.*?\n\n"
// "Second structure:.*\n\n"
// r'More specifically: Substructure "type=list str=\[0, 1\]" '
// r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
// "is not")):
// nest.assert_same_structure([0, 1], np.array([0, 1]))

// with self.assertRaisesRegexp(
// ValueError,
// ("The two structures don't have the same nested structure\\.\n\n"
// "First structure:.*?\n\n"
// "Second structure:.*\n\n"
// r'More specifically: Substructure "type=list str=\[0, 1\]" '
// 'is a sequence, while substructure "type=int str=0" '
// "is not")):
// nest.assert_same_structure(0, [0, 1])

// self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])

// with self.assertRaisesRegexp(
// ValueError,
// ("don't have the same nested structure\\.\n\n"
// "First structure: .*?\n\nSecond structure: ")):
// nest.assert_same_structure(structure1, structure_different_nesting)

// self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
// NestTest.Named0ab("a", "b"))

// nest.assert_same_structure(NestTest.Named0ab(3, 4),
// NestTest.Named0ab("a", "b"))

// self.assertRaises(TypeError, nest.assert_same_structure,
// NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))

// with self.assertRaisesRegexp(
// ValueError,
// ("don't have the same nested structure\\.\n\n"
// "First structure: .*?\n\nSecond structure: ")):
// nest.assert_same_structure(NestTest.Named0ab(3, 4),
// NestTest.Named0ab([3], 4))

// with self.assertRaisesRegexp(
// ValueError,
// ("don't have the same nested structure\\.\n\n"
// "First structure: .*?\n\nSecond structure: ")):
// nest.assert_same_structure([[3], 4], [3, [4]])

// structure1_list = [[[1, 2], 3], 4, [5, 6]]
// with self.assertRaisesRegexp(TypeError,
// "don't have the same sequence type"):
// nest.assert_same_structure(structure1, structure1_list)
// nest.assert_same_structure(structure1, structure2, check_types= False)
// nest.assert_same_structure(structure1, structure1_list, check_types=False)

// with self.assertRaisesRegexp(ValueError,
// "don't have the same set of keys"):
// nest.assert_same_structure({"a": 1}, {"b": 1})

// nest.assert_same_structure(NestTest.SameNameab(0, 1),
// NestTest.SameNameab2(2, 3))

// # This assertion is expected to pass: two namedtuples with the same
// # name and field names are considered to be identical.
// nest.assert_same_structure(
// NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
// NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))

// expected_message = "The two structures don't have the same.*"
// with self.assertRaisesRegexp(ValueError, expected_message):
// nest.assert_same_structure(
// NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
// NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))

// self.assertRaises(TypeError, nest.assert_same_structure,
// NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))

// self.assertRaises(TypeError, nest.assert_same_structure,
// NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))

// self.assertRaises(TypeError, nest.assert_same_structure,
// NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))

// EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name

// def testHeterogeneousComparison(self):
// nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3))
// nest.assert_same_structure(_CustomMapping(b=3), {"b": 4})
[Ignore]
[TestMethod]
public void testMapStructure()
{
    // map_structure applies a function to every leaf while preserving the
    // nesting; given two structures it zips corresponding leaves together.
    var left = new object[] { new object[] { new object[] { 1, 2 }, 3 }, 4, new object[] { 5, 6 } };
    var right = new object[] { new object[] { new object[] { 7, 8 }, 9 }, 10, new object[] { 11, 12 } };

    var incremented = nest.map_structure(x => (int)x + 1, left);
    var stringified = nest.map_structure(x => $"{x}", left);
    Console.WriteLine(JArray.FromObject(incremented).ToString());
    // nest.assert_same_structure(structure1, structure1_plus1)
    self.assertAllEqual(nest.flatten(incremented), new object[] { 2, 3, 4, 5, 6, 7 });
    self.assertAllEqual(nest.flatten(stringified), new object[] { "1", "2", "3", "4", "5", "6" });

    // Two-structure form: each mapped value is the sum of the paired leaves.
    var pairwiseSum = nest.map_structure(x => (int)(x[0]) + (int)(x[1]), left, right);
    self.assertEqual(
        new object[] { new object[] { new object[] { 1 + 7, 2 + 8 }, 3 + 9 }, 4 + 10, new object[] { 5 + 11, 6 + 12 } },
        pairwiseSum);

    // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))

    // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

    // # Empty structures
    // self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
    // self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
    // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
    // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
    //  NestTest.EmptyNT()))

    // # This is checking actual equality of types, empty list != empty tuple
    // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))

    // with self.assertRaisesRegexp(TypeError, "callable"):
    //   nest.map_structure("bad", structure1_plus1)

    // with self.assertRaisesRegexp(ValueError, "at least one structure"):
    //   nest.map_structure(lambda x: x)

    // with self.assertRaisesRegexp(ValueError, "same number of elements"):
    //   nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))

    // with self.assertRaisesRegexp(ValueError, "same nested structure"):
    //   nest.map_structure(lambda x, y: None, 3, (3,))

    // with self.assertRaisesRegexp(TypeError, "same sequence type"):
    //   nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])

    // with self.assertRaisesRegexp(ValueError, "same nested structure"):
    //   nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))

    // structure1_list = [[[1, 2], 3], 4, [5, 6]]
    // with self.assertRaisesRegexp(TypeError, "same sequence type"):
    //   nest.map_structure(lambda x, y: None, structure1, structure1_list)

    // nest.map_structure(lambda x, y: None, structure1, structure1_list,
    //   check_types=False)

    // with self.assertRaisesRegexp(ValueError, "same nested structure"):
    //   nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
    //     check_types=False)

    // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
    //   nest.map_structure(lambda x: None, structure1, foo="a")

    // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
    //   nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")

    // ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name
}

// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testMapStructureWithStrings(self) :
// inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
// inp_b = NestTest.ABTuple(a=2, b=(1, 3))
// out = nest.map_structure(lambda string, repeats: string* repeats,
// inp_a,
// inp_b)
// self.assertEqual("foofoo", out.a)
// self.assertEqual("bar", out.b[0])
// self.assertEqual("bazbazbaz", out.b[1])

// nt = NestTest.ABTuple(a=("something", "something_else"),
// b="yet another thing")
// rev_nt = nest.map_structure(lambda x: x[::- 1], nt)
// # Check the output is the correct structure, and all strings are reversed.
// nest.assert_same_structure(nt, rev_nt)
// self.assertEqual(nt.a[0][::- 1], rev_nt.a[0])
// self.assertEqual(nt.a[1][::- 1], rev_nt.a[1])
// self.assertEqual(nt.b[::- 1], rev_nt.b)

// @test_util.run_deprecated_v1
// def testMapStructureOverPlaceholders(self) :
// inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
// array_ops.placeholder(dtypes.float32, shape=[3, 7]))
// inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
// array_ops.placeholder(dtypes.float32, shape=[3, 7]))

// output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)

// nest.assert_same_structure(output, inp_a)
// self.assertShapeEqual(np.zeros((3, 4)), output[0])
// self.assertShapeEqual(np.zeros((3, 7)), output[1])

// feed_dict = {
// inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
// inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
// }

// with self.cached_session() as sess:
// output_np = sess.run(output, feed_dict=feed_dict)
// self.assertAllClose(output_np[0],
// feed_dict[inp_a][0] + feed_dict[inp_b][0])
// self.assertAllClose(output_np[1],
// feed_dict[inp_a][1] + feed_dict[inp_b][1])

// def testAssertShallowStructure(self):
// inp_ab = ["a", "b"]
//inp_abc = ["a", "b", "c"]
//expected_message = (
// "The two structures don't have the same sequence length. Input "
// "structure has length 2, while shallow structure has length 3.")
// with self.assertRaisesRegexp(ValueError, expected_message):
// nest.assert_shallow_structure(inp_abc, inp_ab)

// inp_ab1 = [(1, 1), (2, 2)]
// inp_ab2 = [[1, 1], [2, 2]]
// expected_message = (
// "The two structures don't have the same sequence type. Input structure "
// "has type <(type|class) 'tuple'>, while shallow structure has type "
// "<(type|class) 'list'>.")
// with self.assertRaisesRegexp(TypeError, expected_message):
// nest.assert_shallow_structure(inp_ab2, inp_ab1)
// nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False)

// inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
// inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
// expected_message = (
// r"The two structures don't have the same keys. Input "
// r"structure has keys \['c'\], while shallow structure has "
// r"keys \['d'\].")

// with self.assertRaisesRegexp(ValueError, expected_message):
// nest.assert_shallow_structure(inp_ab2, inp_ab1)

// inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
// inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
// nest.assert_shallow_structure(inp_ab, inp_ba)

// # This assertion is expected to pass: two namedtuples with the same
//# name and field names are considered to be identical.
//inp_shallow = NestTest.SameNameab(1, 2)
// inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
// nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
// nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)

// def testFlattenUpTo(self):
// # Shallow tree ends at scalar.
// input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
// shallow_tree = [[True, True], [False, True]]
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
// self.assertEqual(flattened_shallow_tree, [True, True, False, True])

//# Shallow tree ends at string.
// input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
// shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// input_tree_flattened = nest.flatten(input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
// self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

// # Make sure dicts are correctly flattened, yielding values, not keys.
//input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
// shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [1, { "c": 2}, 3, (4, 5)])

// # Namedtuples.
// ab_tuple = NestTest.ABTuple
// input_tree = ab_tuple(a =[0, 1], b = 2)
// shallow_tree = ab_tuple(a= 0, b= 1)
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [[0, 1], 2])

// # Nested dicts, OrderedDicts and namedtuples.
// input_tree = collections.OrderedDict(
// [("a", ab_tuple(a =[0, {"b": 1}], b=2)),
// ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
// shallow_tree = input_tree
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
// shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [ab_tuple(a =[0, { "b": 1}], b=2),
// 3,
// collections.OrderedDict([("f", 4)])])
// shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [ab_tuple(a =[0, {"b": 1}], b=2),
// {"d": 3, "e": collections.OrderedDict([("f", 4)])}])

// ## Shallow non-list edge-case.
// # Using iterable elements.
// input_tree = ["input_tree"]
//shallow_tree = "shallow_tree"
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])

// input_tree = ["input_tree_0", "input_tree_1"]
//shallow_tree = "shallow_tree"
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])

// # Using non-iterable elements.
//input_tree = [0]
//shallow_tree = 9
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])

// input_tree = [0, 1]
//shallow_tree = 9
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])

// ## Both non-list edge-case.
//# Using iterable elements.
//input_tree = "input_tree"
// shallow_tree = "shallow_tree"
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])

// # Using non-iterable elements.
//input_tree = 0
// shallow_tree = 0
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])

// ## Input non-list edge-case.
//# Using iterable elements.
//input_tree = "input_tree"
// shallow_tree = ["shallow_tree"]
//expected_message = ("If shallow structure is a sequence, input must also "
// "be a sequence. Input has type: <(type|class) 'str'>.")
// with self.assertRaisesRegexp(TypeError, expected_message):
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_shallow_tree, shallow_tree)

// input_tree = "input_tree"
// shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
//with self.assertRaisesRegexp(TypeError, expected_message):
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_shallow_tree, shallow_tree)

//# Using non-iterable elements.
// input_tree = 0
// shallow_tree = [9]
//expected_message = ("If shallow structure is a sequence, input must also "
// "be a sequence. Input has type: <(type|class) 'int'>.")
// with self.assertRaisesRegexp(TypeError, expected_message):
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_shallow_tree, shallow_tree)

// input_tree = 0
// shallow_tree = [9, 8]
//with self.assertRaisesRegexp(TypeError, expected_message):
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_shallow_tree, shallow_tree)

// def testMapStructureUpTo(self) :
// # Named tuples.
// ab_tuple = collections.namedtuple("ab_tuple", "a, b")
// op_tuple = collections.namedtuple("op_tuple", "add, mul")
// inp_val = ab_tuple(a= 2, b= 3)
// inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3))
// out = nest.map_structure_up_to(
// inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
// self.assertEqual(out.a, 6)
// self.assertEqual(out.b, 15)

// # Lists.
// data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
// name_list = ["evens", ["odds", "primes"]]
// out = nest.map_structure_up_to(
// name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
// name_list, data_list)
// self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])

// # Dicts.
// inp_val = dict(a= 2, b= 3)
// inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3))
// out = nest.map_structure_up_to(
// inp_val,
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
// self.assertEqual(out["a"], 6)
// self.assertEqual(out["b"], 15)

// # Non-equal dicts.
// inp_val = dict(a= 2, b= 3)
// inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3))
// with self.assertRaisesRegexp(ValueError, "same keys"):
// nest.map_structure_up_to(
// inp_val,
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)

// # Dict+custom mapping.
// inp_val = dict(a= 2, b= 3)
// inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3))
// out = nest.map_structure_up_to(
// inp_val,
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
// self.assertEqual(out["a"], 6)
// self.assertEqual(out["b"], 15)

// # Non-equal dict/mapping.
// inp_val = dict(a= 2, b= 3)
// inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3))
// with self.assertRaisesRegexp(ValueError, "same keys"):
// nest.map_structure_up_to(
// inp_val,
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)

// def testGetTraverseShallowStructure(self):
// scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
// scalar_traverse_r = nest.get_traverse_shallow_structure(
// lambda s: not isinstance(s, tuple),
// scalar_traverse_input)
// self.assertEqual(scalar_traverse_r,
// [True, True, False, [True, True], {"a": False}, []])
// nest.assert_shallow_structure(scalar_traverse_r,
// scalar_traverse_input)

// structure_traverse_input = [(1, [2]), ([1], 2)]
// structure_traverse_r = nest.get_traverse_shallow_structure(
// lambda s: (True, False) if isinstance(s, tuple) else True,
// structure_traverse_input)
// self.assertEqual(structure_traverse_r,
// [(True, False), ([True], False)])
// nest.assert_shallow_structure(structure_traverse_r,
// structure_traverse_input)

// with self.assertRaisesRegexp(TypeError, "returned structure"):
// nest.get_traverse_shallow_structure(lambda _: [True], 0)

// with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
// nest.get_traverse_shallow_structure(lambda _: 1, [1])

// with self.assertRaisesRegexp(
// TypeError, "didn't return a depth=1 structure of bools"):
// nest.get_traverse_shallow_structure(lambda _: [1], [1])

// def testYieldFlatStringPaths(self):
// for inputs_expected in ({"inputs": [], "expected": []},
// {"inputs": 3, "expected": [()]},
// {"inputs": [3], "expected": [(0,)]},
// {"inputs": {"a": 3}, "expected": [("a",)]},
// {"inputs": {"a": {"b": 4}},
// "expected": [("a", "b")]},
// {"inputs": [{"a": 2}], "expected": [(0, "a")]},
// {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
// {"inputs": [{"a": [(23, 42)]}],
// "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
// {"inputs": [{"a": ([23], 42)}],
// "expected": [(0, "a", 0, 0), (0, "a", 1)]},
// {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
// "expected": [("a", "a"), ("c", 0, 0, 0)]},
// {"inputs": {"0": [{"1": 23}]},
// "expected": [("0", 0, "1")]}):
// inputs = inputs_expected["inputs"]
// expected = inputs_expected["expected"]
// self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)

// def testFlattenWithStringPaths(self):
// for inputs_expected in (
// {"inputs": [], "expected": []},
// {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
// {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
// inputs = inputs_expected["inputs"]
// expected = inputs_expected["expected"]
// self.assertEqual(
// nest.flatten_with_joined_string_paths(inputs, separator="/"),
// expected)

// # Need a separate test for namedtuple as we can't declare tuple definitions
// # in the @parameterized arguments.
// def testFlattenNamedTuple(self):
// # pylint: disable=invalid-name
// Foo = collections.namedtuple("Foo", ["a", "b"])
// Bar = collections.namedtuple("Bar", ["c", "d"])
// # pylint: enable=invalid-name
// test_cases = [
// (Foo(a = 3, b = Bar(c = 23, d = 42)),
// [("a", 3), ("b/c", 23), ("b/d", 42)]),
// (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")),
// [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
// (Bar(c = 42, d = 43),
// [("c", 42), ("d", 43)]),
// (Bar(c =[42], d = 43),
// [("c/0", 42), ("d", 43)]),
// ]
// for inputs, expected in test_cases:
// self.assertEqual(
// list(nest.flatten_with_joined_string_paths(inputs)), expected)

// @parameterized.named_parameters(
// ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
// ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
// {"a": ("a", 4), "b": ("b", 6)}),
// ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
// ("nested",
// {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
// {"a": [("a/0", 10), ("a/1", 12)],
// "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
// def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
// def format_sum(path, * values):
// return (path, sum(values))
// result = nest.map_structure_with_paths(format_sum, s1, s2,
// check_types=check_types)
// self.assertEqual(expected, result)

// @parameterized.named_parameters(
// ("tuples", (1, 2), (3, 4, 5), ValueError),
// ("dicts", {"a": 1}, {"b": 2}, ValueError),
// ("mixed", (1, 2), [3, 4], TypeError),
// ("nested",
// {"a": [2, 3], "b": [1, 3]},
// {"b": [5, 6, 7], "a": [8, 9]},
// ValueError
// ))
// def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
// with self.assertRaises(error_type):
// nest.map_structure_with_paths(lambda path, * s: 0, s1, s2)


//class NestBenchmark(test.Benchmark):

// def run_and_report(self, s1, s2, name):
// burn_iter, test_iter = 100, 30000

// for _ in xrange(burn_iter) :
// nest.assert_same_structure(s1, s2)

// t0 = time.time()
// for _ in xrange(test_iter) :
// nest.assert_same_structure(s1, s2)
// t1 = time.time()

// self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
// name=name)

// def benchmark_assert_structure(self):
// s1 = (((1, 2), 3), 4, (5, 6))
// s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
// self.run_and_report(s1, s2, "assert_same_structure_6_elem")

// s1 = (((1, 2), 3), 4, (5, 6)) * 10
// s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
// self.run_and_report(s1, s2, "assert_same_structure_60_elem")


//if __name__ == "__main__":
// test.main()
}
}

+ 0
- 883
test/TensorFlowNET.UnitTest/nest_test/nest_test.py View File

@@ -1,883 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utilities working with arbitrarily nested structures."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import collections.abc
import time

from absl.testing import parameterized
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest

try:
import attr # pylint:disable=g-import-not-at-top
except ImportError:
attr = None


class _CustomMapping(collections.Mapping):

def __init__(self, *args, **kwargs):
self._wrapped = dict(*args, **kwargs)

def __getitem__(self, key):
return self._wrapped[key]

def __iter__(self):
return iter(self._wrapped)

def __len__(self):
return len(self._wrapped)


class NestTest(parameterized.TestCase, test.TestCase):

PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name

if attr:
class BadAttr(object):
"""Class that has a non-iterable __attrs_attrs__."""
__attrs_attrs__ = None

@attr.s
class SampleAttr(object):
field1 = attr.ib()
field2 = attr.ib()

@test_util.assert_no_new_pyobjects_executing_eagerly
def testAttrsFlattenAndPack(self):
if attr is None:
self.skipTest("attr module is unavailable.")

field_values = [1, 2]
sample_attr = NestTest.SampleAttr(*field_values)
self.assertFalse(nest._is_attrs(field_values))
self.assertTrue(nest._is_attrs(sample_attr))
flat = nest.flatten(sample_attr)
self.assertEqual(field_values, flat)
restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
self.assertEqual(restructured_from_flat, sample_attr)

# Check that flatten fails if attributes are not iterable
with self.assertRaisesRegexp(TypeError, "object is not iterable"):
flat = nest.flatten(NestTest.BadAttr())

@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack(self):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
self.assertEqual(
nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
("d", "e", ("f", "g"), "h")))
structure = (NestTest.PointXY(x=4, y=2),
((NestTest.PointXY(x=1, y=0),),))
flat = [4, 2, 1, 0]
self.assertEqual(nest.flatten(structure), flat)
restructured_from_flat = nest.pack_sequence_as(structure, flat)
self.assertEqual(restructured_from_flat, structure)
self.assertEqual(restructured_from_flat[0].x, 4)
self.assertEqual(restructured_from_flat[0].y, 2)
self.assertEqual(restructured_from_flat[1][0][0].x, 1)
self.assertEqual(restructured_from_flat[1][0][0].y, 0)

self.assertEqual([5], nest.flatten(5))
self.assertEqual([np.array([5])], nest.flatten(np.array([5])))

self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
self.assertEqual(
np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
nest.pack_sequence_as("scalar", [4, 5])

with self.assertRaisesRegexp(TypeError, "flat_sequence"):
nest.pack_sequence_as([4, 5], "bad_sequence")

with self.assertRaises(ValueError):
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])

@parameterized.parameters({"mapping_type": collections.OrderedDict},
{"mapping_type": _CustomMapping})
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenDictOrder(self, mapping_type):
"""`flatten` orders dicts by key, including OrderedDicts."""
ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
ordered_flat = nest.flatten(ordered)
plain_flat = nest.flatten(plain)
self.assertEqual([0, 1, 2, 3], ordered_flat)
self.assertEqual([0, 1, 2, 3], plain_flat)

@parameterized.parameters({"mapping_type": collections.OrderedDict},
{"mapping_type": _CustomMapping})
def testPackDictOrder(self, mapping_type):
"""Packing orders dicts by key, including OrderedDicts."""
custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
plain = {"d": 0, "b": 0, "a": 0, "c": 0}
seq = [0, 1, 2, 3]
custom_reconstruction = nest.pack_sequence_as(custom, seq)
plain_reconstruction = nest.pack_sequence_as(plain, seq)
self.assertIsInstance(custom_reconstruction, mapping_type)
self.assertIsInstance(plain_reconstruction, dict)
self.assertEqual(
mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
custom_reconstruction)
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)

Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name

@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack_withDicts(self):
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
mess = [
"z",
NestTest.Abc(3, 4), {
"d": _CustomMapping({
41: 4
}),
"c": [
1,
collections.OrderedDict([
("b", 3),
("a", 2),
]),
],
"b": 5
}, 17
]

flattened = nest.flatten(mess)
self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])

structure_of_mess = [
14,
NestTest.Abc("a", True),
{
"d": _CustomMapping({
41: 42
}),
"c": [
0,
collections.OrderedDict([
("b", 9),
("a", 8),
]),
],
"b": 3
},
"hi everybody",
]

unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
self.assertEqual(unflattened, mess)

# Check also that the OrderedDict was created, with the correct key order.
unflattened_ordered_dict = unflattened[2]["c"][1]
self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])

unflattened_custom_mapping = unflattened[2]["d"]
self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
self.assertEqual(list(unflattened_custom_mapping.keys()), [41])

def testFlatten_numpyIsNotFlattened(self):
    """A numpy array is a single leaf; flatten must not descend into it."""
    arr = np.array([1, 2, 3])
    leaves = nest.flatten(arr)
    self.assertEqual(1, len(leaves))

def testFlatten_stringIsNotFlattened(self):
    """A plain string is one scalar leaf, and packing restores it intact."""
    text = "lots of letters"
    leaves = nest.flatten(text)
    self.assertEqual(1, len(leaves))
    rebuilt = nest.pack_sequence_as("goodbye", leaves)
    self.assertEqual(text, rebuilt)

def testPackSequenceAs_notIterableError(self):
with self.assertRaisesRegexp(TypeError,
"flat_sequence must be a sequence"):
nest.pack_sequence_as("hi", "bye")

def testPackSequenceAs_wrongLengthsError(self):
with self.assertRaisesRegexp(
ValueError,
"Structure had 2 elements, but flat_sequence had 3 elements."):
nest.pack_sequence_as(["hello", "world"],
["and", "goodbye", "again"])

@test_util.assert_no_new_pyobjects_executing_eagerly
def testIsSequence(self):
self.assertFalse(nest.is_sequence("1234"))
self.assertTrue(nest.is_sequence([1, 3, [4, 5]]))
self.assertTrue(nest.is_sequence(((7, 8), (5, 6))))
self.assertTrue(nest.is_sequence([]))
self.assertTrue(nest.is_sequence({"a": 1, "b": 2}))
self.assertFalse(nest.is_sequence(set([1, 2])))
ones = array_ops.ones([2, 3])
self.assertFalse(nest.is_sequence(ones))
self.assertFalse(nest.is_sequence(math_ops.tanh(ones)))
self.assertFalse(nest.is_sequence(np.ones((4, 5))))

@parameterized.parameters({"mapping_type": _CustomMapping},
{"mapping_type": dict})
def testFlattenDictItems(self, mapping_type):
dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))})
flat = {4: "a", 5: "b", 6: "c", 8: "d"}
self.assertEqual(nest.flatten_dict_items(dictionary), flat)

with self.assertRaises(TypeError):
nest.flatten_dict_items(4)

bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))})
with self.assertRaisesRegexp(ValueError, "not unique"):
nest.flatten_dict_items(bad_dictionary)

another_bad_dictionary = mapping_type({
(4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
})
with self.assertRaisesRegexp(
ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
nest.flatten_dict_items(another_bad_dictionary)

# pylint does not correctly recognize these as class names and
# suggests to use variable style under_score naming.
# pylint: disable=invalid-name
Named0ab = collections.namedtuple("named_0", ("a", "b"))
Named1ab = collections.namedtuple("named_1", ("a", "b"))
SameNameab = collections.namedtuple("same_name", ("a", "b"))
SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
SameNamexy = collections.namedtuple("same_name", ("x", "y"))
SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
# pylint: enable=invalid-name

class SameNamedType1(SameNameab):
pass

@test_util.assert_no_new_pyobjects_executing_eagerly
def testAssertSameStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
structure_different_num_elements = ("spam", "eggs")
structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
nest.assert_same_structure(structure1, structure2)
nest.assert_same_structure("abc", 1.0)
nest.assert_same_structure("abc", np.array([0, 1]))
nest.assert_same_structure("abc", constant_op.constant([0, 1]))

with self.assertRaisesRegexp(
ValueError,
("The two structures don't have the same nested structure\\.\n\n"
"First structure:.*?\n\n"
"Second structure:.*\n\n"
"More specifically: Substructure "
r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
'substructure "type=str str=spam" is not\n'
"Entire first structure:\n"
r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
"Entire second structure:\n"
r"\(\., \.\)")):
nest.assert_same_structure(structure1, structure_different_num_elements)

with self.assertRaisesRegexp(
ValueError,
("The two structures don't have the same nested structure\\.\n\n"
"First structure:.*?\n\n"
"Second structure:.*\n\n"
r'More specifically: Substructure "type=list str=\[0, 1\]" '
r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
"is not")):
nest.assert_same_structure([0, 1], np.array([0, 1]))

with self.assertRaisesRegexp(
ValueError,
("The two structures don't have the same nested structure\\.\n\n"
"First structure:.*?\n\n"
"Second structure:.*\n\n"
r'More specifically: Substructure "type=list str=\[0, 1\]" '
'is a sequence, while substructure "type=int str=0" '
"is not")):
nest.assert_same_structure(0, [0, 1])

self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])

with self.assertRaisesRegexp(
ValueError,
("don't have the same nested structure\\.\n\n"
"First structure: .*?\n\nSecond structure: ")):
nest.assert_same_structure(structure1, structure_different_nesting)

self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
NestTest.Named0ab("a", "b"))

nest.assert_same_structure(NestTest.Named0ab(3, 4),
NestTest.Named0ab("a", "b"))

self.assertRaises(TypeError, nest.assert_same_structure,
NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))

with self.assertRaisesRegexp(
ValueError,
("don't have the same nested structure\\.\n\n"
"First structure: .*?\n\nSecond structure: ")):
nest.assert_same_structure(NestTest.Named0ab(3, 4),
NestTest.Named0ab([3], 4))

with self.assertRaisesRegexp(
ValueError,
("don't have the same nested structure\\.\n\n"
"First structure: .*?\n\nSecond structure: ")):
nest.assert_same_structure([[3], 4], [3, [4]])

structure1_list = [[[1, 2], 3], 4, [5, 6]]
with self.assertRaisesRegexp(TypeError,
"don't have the same sequence type"):
nest.assert_same_structure(structure1, structure1_list)
nest.assert_same_structure(structure1, structure2, check_types=False)
nest.assert_same_structure(structure1, structure1_list, check_types=False)

with self.assertRaisesRegexp(ValueError,
"don't have the same set of keys"):
nest.assert_same_structure({"a": 1}, {"b": 1})

nest.assert_same_structure(NestTest.SameNameab(0, 1),
NestTest.SameNameab2(2, 3))

# This assertion is expected to pass: two namedtuples with the same
# name and field names are considered to be identical.
nest.assert_same_structure(
NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))

expected_message = "The two structures don't have the same.*"
with self.assertRaisesRegexp(ValueError, expected_message):
nest.assert_same_structure(
NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))

self.assertRaises(TypeError, nest.assert_same_structure,
NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))

self.assertRaises(TypeError, nest.assert_same_structure,
NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))

self.assertRaises(TypeError, nest.assert_same_structure,
NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))

EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name

def testHeterogeneousComparison(self):
    """A dict and a _CustomMapping with the same keys match structurally."""
    cases = [({"a": 4}, _CustomMapping(a=3)),
             (_CustomMapping(b=3), {"b": 4})]
    for lhs, rhs in cases:
        nest.assert_same_structure(lhs, rhs)

@test_util.assert_no_new_pyobjects_executing_eagerly
def testMapStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
structure2 = (((7, 8), 9), 10, (11, 12))
structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
nest.assert_same_structure(structure1, structure1_plus1)
self.assertAllEqual(
[2, 3, 4, 5, 6, 7],
nest.flatten(structure1_plus1))
structure1_plus_structure2 = nest.map_structure(
lambda x, y: x + y, structure1, structure2)
self.assertEqual(
(((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
structure1_plus_structure2)

self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))

self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

# Empty structures
self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
NestTest.EmptyNT()))

# This is checking actual equality of types, empty list != empty tuple
self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))

with self.assertRaisesRegexp(TypeError, "callable"):
nest.map_structure("bad", structure1_plus1)

with self.assertRaisesRegexp(ValueError, "at least one structure"):
nest.map_structure(lambda x: x)

with self.assertRaisesRegexp(ValueError, "same number of elements"):
nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))

with self.assertRaisesRegexp(ValueError, "same nested structure"):
nest.map_structure(lambda x, y: None, 3, (3,))

with self.assertRaisesRegexp(TypeError, "same sequence type"):
nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])

with self.assertRaisesRegexp(ValueError, "same nested structure"):
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))

structure1_list = [[[1, 2], 3], 4, [5, 6]]
with self.assertRaisesRegexp(TypeError, "same sequence type"):
nest.map_structure(lambda x, y: None, structure1, structure1_list)

nest.map_structure(lambda x, y: None, structure1, structure1_list,
check_types=False)

with self.assertRaisesRegexp(ValueError, "same nested structure"):
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
check_types=False)

with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
nest.map_structure(lambda x: None, structure1, foo="a")

with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")

ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name

@test_util.assert_no_new_pyobjects_executing_eagerly
def testMapStructureWithStrings(self):
inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
inp_b = NestTest.ABTuple(a=2, b=(1, 3))
out = nest.map_structure(lambda string, repeats: string * repeats,
inp_a,
inp_b)
self.assertEqual("foofoo", out.a)
self.assertEqual("bar", out.b[0])
self.assertEqual("bazbazbaz", out.b[1])

nt = NestTest.ABTuple(a=("something", "something_else"),
b="yet another thing")
rev_nt = nest.map_structure(lambda x: x[::-1], nt)
# Check the output is the correct structure, and all strings are reversed.
nest.assert_same_structure(nt, rev_nt)
self.assertEqual(nt.a[0][::-1], rev_nt.a[0])
self.assertEqual(nt.a[1][::-1], rev_nt.a[1])
self.assertEqual(nt.b[::-1], rev_nt.b)

@test_util.run_deprecated_v1
def testMapStructureOverPlaceholders(self):
inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
array_ops.placeholder(dtypes.float32, shape=[3, 7]))
inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
array_ops.placeholder(dtypes.float32, shape=[3, 7]))

output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)

nest.assert_same_structure(output, inp_a)
self.assertShapeEqual(np.zeros((3, 4)), output[0])
self.assertShapeEqual(np.zeros((3, 7)), output[1])

feed_dict = {
inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
}

with self.cached_session() as sess:
output_np = sess.run(output, feed_dict=feed_dict)
self.assertAllClose(output_np[0],
feed_dict[inp_a][0] + feed_dict[inp_b][0])
self.assertAllClose(output_np[1],
feed_dict[inp_a][1] + feed_dict[inp_b][1])

def testAssertShallowStructure(self):
inp_ab = ["a", "b"]
inp_abc = ["a", "b", "c"]
expected_message = (
"The two structures don't have the same sequence length. Input "
"structure has length 2, while shallow structure has length 3.")
with self.assertRaisesRegexp(ValueError, expected_message):
nest.assert_shallow_structure(inp_abc, inp_ab)

inp_ab1 = [(1, 1), (2, 2)]
inp_ab2 = [[1, 1], [2, 2]]
expected_message = (
"The two structures don't have the same sequence type. Input structure "
"has type <(type|class) 'tuple'>, while shallow structure has type "
"<(type|class) 'list'>.")
with self.assertRaisesRegexp(TypeError, expected_message):
nest.assert_shallow_structure(inp_ab2, inp_ab1)
nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)

inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
expected_message = (
r"The two structures don't have the same keys. Input "
r"structure has keys \['c'\], while shallow structure has "
r"keys \['d'\].")

with self.assertRaisesRegexp(ValueError, expected_message):
nest.assert_shallow_structure(inp_ab2, inp_ab1)

inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
nest.assert_shallow_structure(inp_ab, inp_ba)

# This assertion is expected to pass: two namedtuples with the same
# name and field names are considered to be identical.
inp_shallow = NestTest.SameNameab(1, 2)
inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)

def testFlattenUpTo(self):
# Shallow tree ends at scalar.
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
self.assertEqual(flattened_shallow_tree, [True, True, False, True])

# Shallow tree ends at string.
input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
input_tree_flattened = nest.flatten(input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[("a", 1), ("b", 2), ("c", 3), ("d", 4)])
self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

# Make sure dicts are correctly flattened, yielding values, not keys.
input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[1, {"c": 2}, 3, (4, 5)])

# Namedtuples.
ab_tuple = NestTest.ABTuple
input_tree = ab_tuple(a=[0, 1], b=2)
shallow_tree = ab_tuple(a=0, b=1)
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[[0, 1], 2])

# Nested dicts, OrderedDicts and namedtuples.
input_tree = collections.OrderedDict(
[("a", ab_tuple(a=[0, {"b": 1}], b=2)),
("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
shallow_tree = input_tree
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[ab_tuple(a=[0, {"b": 1}], b=2),
3,
collections.OrderedDict([("f", 4)])])
shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[ab_tuple(a=[0, {"b": 1}], b=2),
{"d": 3, "e": collections.OrderedDict([("f", 4)])}])

## Shallow non-list edge-case.
# Using iterable elements.
input_tree = ["input_tree"]
shallow_tree = "shallow_tree"
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])

input_tree = ["input_tree_0", "input_tree_1"]
shallow_tree = "shallow_tree"
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])

# Using non-iterable elements.
input_tree = [0]
shallow_tree = 9
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])

input_tree = [0, 1]
shallow_tree = 9
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])

## Both non-list edge-case.
# Using iterable elements.
input_tree = "input_tree"
shallow_tree = "shallow_tree"
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])

# Using non-iterable elements.
input_tree = 0
shallow_tree = 0
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])

## Input non-list edge-case.
# Using iterable elements.
input_tree = "input_tree"
shallow_tree = ["shallow_tree"]
expected_message = ("If shallow structure is a sequence, input must also "
"be a sequence. Input has type: <(type|class) 'str'>.")
with self.assertRaisesRegexp(TypeError, expected_message):
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)

input_tree = "input_tree"
shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
with self.assertRaisesRegexp(TypeError, expected_message):
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)

# Using non-iterable elements.
input_tree = 0
shallow_tree = [9]
expected_message = ("If shallow structure is a sequence, input must also "
"be a sequence. Input has type: <(type|class) 'int'>.")
with self.assertRaisesRegexp(TypeError, expected_message):
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)

input_tree = 0
shallow_tree = [9, 8]
with self.assertRaisesRegexp(TypeError, expected_message):
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)

def testMapStructureUpTo(self):
# Named tuples.
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")
inp_val = ab_tuple(a=2, b=3)
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
out = nest.map_structure_up_to(
inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
self.assertEqual(out.a, 6)
self.assertEqual(out.b, 15)

# Lists.
data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
name_list = ["evens", ["odds", "primes"]]
out = nest.map_structure_up_to(
name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
name_list, data_list)
self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])

# Dicts.
inp_val = dict(a=2, b=3)
inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
out = nest.map_structure_up_to(
inp_val,
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
self.assertEqual(out["a"], 6)
self.assertEqual(out["b"], 15)

# Non-equal dicts.
inp_val = dict(a=2, b=3)
inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
with self.assertRaisesRegexp(ValueError, "same keys"):
nest.map_structure_up_to(
inp_val,
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)

# Dict+custom mapping.
inp_val = dict(a=2, b=3)
inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
out = nest.map_structure_up_to(
inp_val,
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
self.assertEqual(out["a"], 6)
self.assertEqual(out["b"], 15)

# Non-equal dict/mapping.
inp_val = dict(a=2, b=3)
inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
with self.assertRaisesRegexp(ValueError, "same keys"):
nest.map_structure_up_to(
inp_val,
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)

def testGetTraverseShallowStructure(self):
scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
scalar_traverse_r = nest.get_traverse_shallow_structure(
lambda s: not isinstance(s, tuple),
scalar_traverse_input)
self.assertEqual(scalar_traverse_r,
[True, True, False, [True, True], {"a": False}, []])
nest.assert_shallow_structure(scalar_traverse_r,
scalar_traverse_input)

structure_traverse_input = [(1, [2]), ([1], 2)]
structure_traverse_r = nest.get_traverse_shallow_structure(
lambda s: (True, False) if isinstance(s, tuple) else True,
structure_traverse_input)
self.assertEqual(structure_traverse_r,
[(True, False), ([True], False)])
nest.assert_shallow_structure(structure_traverse_r,
structure_traverse_input)

with self.assertRaisesRegexp(TypeError, "returned structure"):
nest.get_traverse_shallow_structure(lambda _: [True], 0)

with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
nest.get_traverse_shallow_structure(lambda _: 1, [1])

with self.assertRaisesRegexp(
TypeError, "didn't return a depth=1 structure of bools"):
nest.get_traverse_shallow_structure(lambda _: [1], [1])

def testYieldFlatStringPaths(self):
for inputs_expected in ({"inputs": [], "expected": []},
{"inputs": 3, "expected": [()]},
{"inputs": [3], "expected": [(0,)]},
{"inputs": {"a": 3}, "expected": [("a",)]},
{"inputs": {"a": {"b": 4}},
"expected": [("a", "b")]},
{"inputs": [{"a": 2}], "expected": [(0, "a")]},
{"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
{"inputs": [{"a": [(23, 42)]}],
"expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
{"inputs": [{"a": ([23], 42)}],
"expected": [(0, "a", 0, 0), (0, "a", 1)]},
{"inputs": {"a": {"a": 2}, "c": [[[4]]]},
"expected": [("a", "a"), ("c", 0, 0, 0)]},
{"inputs": {"0": [{"1": 23}]},
"expected": [("0", 0, "1")]}):
inputs = inputs_expected["inputs"]
expected = inputs_expected["expected"]
self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)

def testFlattenWithStringPaths(self):
for inputs_expected in (
{"inputs": [], "expected": []},
{"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
{"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
inputs = inputs_expected["inputs"]
expected = inputs_expected["expected"]
self.assertEqual(
nest.flatten_with_joined_string_paths(inputs, separator="/"),
expected)

# Need a separate test for namedtuple as we can't declare tuple definitions
# in the @parameterized arguments.
def testFlattenNamedTuple(self):
# pylint: disable=invalid-name
Foo = collections.namedtuple("Foo", ["a", "b"])
Bar = collections.namedtuple("Bar", ["c", "d"])
# pylint: enable=invalid-name
test_cases = [
(Foo(a=3, b=Bar(c=23, d=42)),
[("a", 3), ("b/c", 23), ("b/d", 42)]),
(Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")),
[("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
(Bar(c=42, d=43),
[("c", 42), ("d", 43)]),
(Bar(c=[42], d=43),
[("c/0", 42), ("d", 43)]),
]
for inputs, expected in test_cases:
self.assertEqual(
list(nest.flatten_with_joined_string_paths(inputs)), expected)

@parameterized.named_parameters(
("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
{"a": ("a", 4), "b": ("b", 6)}),
("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
("nested",
{"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
{"a": [("a/0", 10), ("a/1", 12)],
"b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
def format_sum(path, *values):
return (path, sum(values))
result = nest.map_structure_with_paths(format_sum, s1, s2,
check_types=check_types)
self.assertEqual(expected, result)

@parameterized.named_parameters(
("tuples", (1, 2), (3, 4, 5), ValueError),
("dicts", {"a": 1}, {"b": 2}, ValueError),
("mixed", (1, 2), [3, 4], TypeError),
("nested",
{"a": [2, 3], "b": [1, 3]},
{"b": [5, 6, 7], "a": [8, 9]},
ValueError
))
def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
with self.assertRaises(error_type):
nest.map_structure_with_paths(lambda path, *s: 0, s1, s2)


class NestBenchmark(test.Benchmark):
    """Micro-benchmarks for `nest.assert_same_structure`."""

    def run_and_report(self, s1, s2, name):
        """Time assert_same_structure(s1, s2) and report the mean wall time.

        Args:
          s1: First structure to compare.
          s2: Second structure, shaped identically to `s1`.
          name: Benchmark name to report results under.
        """
        burn_iter, test_iter = 100, 30000

        # Warm-up: keep one-time costs (caches, lazy imports) out of timing.
        for _ in range(burn_iter):
            nest.assert_same_structure(s1, s2)

        # perf_counter is monotonic and high-resolution; time.time() is a
        # wall clock that can step backwards and has coarse granularity,
        # which distorts short measured intervals.
        t0 = time.perf_counter()
        for _ in range(test_iter):
            nest.assert_same_structure(s1, s2)
        t1 = time.perf_counter()

        self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
                              name=name)

    def benchmark_assert_structure(self):
        """Benchmark a small (6-leaf) and a larger (60-leaf) structure."""
        s1 = (((1, 2), 3), 4, (5, 6))
        s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
        self.run_and_report(s1, s2, "assert_same_structure_6_elem")

        s1 = (((1, 2), 3), 4, (5, 6)) * 10
        s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
        self.run_and_report(s1, s2, "assert_same_structure_60_elem")


if __name__ == "__main__":
test.main()

+ 0
- 249
test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs View File

@@ -1,249 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Linq;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.ops_test
{
/// <summary>
/// excerpt of tensorflow/python/framework/ops_test.py
/// </summary>
[TestClass]
public class ControlDependenciesTest : GraphModeTestBase
{
[TestMethod]
public void TestBasic()
{
    // An op created inside a control_dependencies scope records the
    // dependency op in its control_inputs; an op already dominated by a
    // dependent op gets no redundant control edge.
    var g = tf.Graph().as_default();

    var a = constant_op.constant(1.0);
    var b = constant_op.constant(1.0);
    Tensor c = null, d = null, e = null;

    tf_with(g.control_dependencies(new[] { a }), _ =>
    {
        c = constant_op.constant(1.0);
        d = array_ops.identity(b);
        e = array_ops.identity(c);
    });

    var expected = new[] { a.op };
    Assert.IsTrue(Enumerable.SequenceEqual(expected, c.op.control_inputs));
    Assert.IsTrue(Enumerable.SequenceEqual(expected, d.op.control_inputs));
    // e is dominated by c, so it needs no explicit control input.
    Assert.AreEqual(0, e.op.control_inputs.Length);
}

[Ignore("How to port the ConvertibleObj?")]
[TestMethod]
public void TestBasicWithConversion()
{
    // Port of ops_test.py testBasicWithConversion: control_dependencies
    // should accept any object exposing _as_graph_element(). Ignored until
    // an equivalent conversion hook exists in the C# binding.
    var g = tf.Graph().as_default();
    // Note: _apply_op can be replaced by g.create_op
    var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
    // TODO: ConvertibleObj, see original source below
    /*
    def testBasicWithConversion(self):
      g = ops.Graph()
      a = _apply_op(g, "FloatOutput", [], [dtypes.float32])

      class ConvertibleObj(object):

        def _as_graph_element(self):
          return a

      with g.control_dependencies([ConvertibleObj()]):
        c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

      self.assertEqual(c.op.control_inputs, [a.op])
    */
}

[TestMethod]
public void TestNested()
{
var g = tf.Graph().as_default();
var a_1 = constant_op.constant(1.0);
var a_2 = constant_op.constant(3.0);
var a_3 = constant_op.constant(4.0);
var a_4 = constant_op.constant(5.0);
Tensor b_1 = null, b_2 = null;
tf_with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl =>
{
b_1 = constant_op.constant(6.0);
});
tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
{
tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
{
tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
{
tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
{
b_2 = constant_op.constant(7.0);
});
});
});
});
//var z=tf.add(a_1, tf.multiply(b_2, b_1));
//with(g.control_dependencies(new[] {z}), ctrl =>
//{
// var z1 = tf.add(a_3, tf.multiply(a_4, a_2));
//});
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op });
assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs);
}

[TestMethod]
public void TestClear()
{
var g = tf.Graph().as_default();
var a_1 = constant_op.constant(1.0);
var a_2 = constant_op.constant(3.0);
var a_3 = constant_op.constant(4.0);
var a_4 = constant_op.constant(5.0);
Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null;
tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
{
tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
{
tf_with(g.control_dependencies(null), ctrl3 =>
{
tf_with(g.control_dependencies(new[] { a_3 }), ctrl4 =>
{
tf_with(g.control_dependencies(new[] { a_4 }), ctrl5 =>
{
// deps [a_3, a_4]
b_3_4 = constant_op.constant(7.0);
});
// deps = [a_3]
b_3 = constant_op.constant(8.0);
});
// deps back to None
b_none = constant_op.constant(9.0);
});
// deps back to [a_1, a_2]
b_1_2 = constant_op.constant(10.0);
});
// deps back to [a_1]
b_1 = constant_op.constant(11.0);
tf_with(g.control_dependencies(null), ctrl6 =>
{
// deps are None again
b_none2 = constant_op.constant(12.0);
});
});
// Note assertItemsEqual(given, expected), expected and given parameters should be swapped below
assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs);
assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs);
assertItemsEqual(new object[0], b_none.op.control_inputs);
assertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs);
assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs);
assertItemsEqual(new object[0], b_none2.op.control_inputs);
}

[TestMethod]
public void TestComplex()
{
var g = tf.Graph().as_default();
// Usage pattern:
// * Nodes a_i are constants defined at the outermost scope, and are used
// as control inputs for the ith nested scope.
// * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
// * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
// * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
// * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.
var a_1 = constant_op.constant(1.0);
var a_2 = constant_op.constant(2.0);
var a_3 = constant_op.constant(3.0);
var a_4 = constant_op.constant(4.0);
Operation b_1 = null, b_2 = null, b_3 = null, b_4 = null;
Operation c_1 = null, c_2 = null, c_3 = null, c_4 = null;
Operation d_1 = null, d_2 = null, d_3 = null, d_4 = null;
Operation e_1 = null, e_2 = null, e_3 = null, e_4 = null;
tf_with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
{
b_1 = tf.multiply(a_3, a_4);
c_1 = tf.multiply(a_1, b_1.output);
d_1 = tf.multiply(b_1.output, c_1.output);
e_1 = constant_op.constant(5.0);
tf_with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
{
b_2 = tf.multiply(a_3, a_4);
c_2 = tf.multiply(a_1, b_1.output);
d_2 = tf.multiply(b_2.output, c_2.output);
e_2 = tf.multiply(e_1.output, e_1.output);
tf_with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
{
b_3 = tf.multiply(a_3, a_4);
c_3 = tf.multiply(a_1, b_1.output);
d_3 = tf.multiply(b_3.output, c_3.output);
e_3 = tf.multiply(e_2.output, e_2.output);
tf_with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
{
b_4 = tf.multiply(a_3, a_4);
c_4 = tf.multiply(a_1, b_1.output);
d_4 = tf.multiply(b_4.output, c_4.output);
e_4 = tf.multiply(e_3.output, e_3.output);
});
});
});
});

// Note assertItemsEqual(given, expected), expected and given parameters should be swapped below
assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs);
assertItemsEqual(new[] { a_1.op, a_2.op }, b_2.op.control_inputs);
assertItemsEqual(new[] { a_1.op, a_2.op }, b_3.op.control_inputs);
assertItemsEqual(new[] { a_1.op, a_2.op }, b_4.op.control_inputs);

assertItemsEqual(new object[0], c_1.op.control_inputs);
assertItemsEqual(new[] { a_2.op }, c_2.op.control_inputs);
assertItemsEqual(new[] { a_2.op, a_3.op }, c_3.op.control_inputs);
assertItemsEqual(new[] { a_2.op, a_3.op, a_4.op }, c_4.op.control_inputs);

assertItemsEqual(new object[0], d_1.op.control_inputs);
assertItemsEqual(new object[0], d_2.op.control_inputs);
assertItemsEqual(new object[0], d_3.op.control_inputs);
assertItemsEqual(new object[0], d_4.op.control_inputs);

assertItemsEqual(new[] { a_1.op }, e_1.op.control_inputs);
assertItemsEqual(new[] { a_2.op }, e_2.op.control_inputs);
assertItemsEqual(new[] { a_3.op }, e_3.op.control_inputs);
assertItemsEqual(new[] { a_4.op }, e_4.op.control_inputs);
}

[Ignore("Don't know how to create an operation with two outputs")]
[TestMethod]
public void TestRepeatedDependency()
{
/*
def testRepeatedDependency(self):
g = ops.Graph()
a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
a_0, a_1 = a.outputs
with g.control_dependencies([a_0]):
b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
with g.control_dependencies([a_1]):
c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

self.assertEqual(b.op.control_inputs, [a])
self.assertEqual(c.op.control_inputs, [a])
*/
}

[TestMethod]
public void TestNoControlDependencyWithDataDependency()
{
var g = tf.Graph().as_default();
Operation b = null;
var a = constant_op.constant(100.0);
tf_with(g.control_dependencies(new[] { a }), ctrl1 =>
{
b = array_ops.identity(a);
});
Assert.AreEqual(0, b.op.control_inputs.Length);
}

}
}

+ 0
- 222
test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs View File

@@ -1,222 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq;
using Tensorflow;
using Tensorflow.Operations;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.ops_test
{
    /// <summary>
    /// excerpt of tensorflow/python/framework/ops_test.py
    /// # These cases test the private Graph._create_op_from_tf_operation
    /// # method. Arguably we should only test the public APIs that depend on this
    /// # method. However, this logic is complex and tricky, and it can be difficult to
    /// # ascertain if we have adequate coverage (e.g. a graph may run successfully if
    /// # the control flow context isn't set properly, but a more complicated use case
    /// # that might not be obvious to test will fail). Thus we instead explicitly test
    /// # the low-level behavior.
    /// </summary>
    [Ignore]
    [TestClass]
    public class CreateOpFromTfOperationTest : GraphModeTestBase
    {

        /// <summary>
        /// An op created from a raw TF_Operation exposes the expected name,
        /// type, and inferred output shape.
        /// </summary>
        [TestMethod]
        public void TestShape()
        {
            using (var g = tf.Graph().as_default())
            {
                var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } });
                var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]);
                var op = g._create_op_from_tf_operation(c_op);

                Assert.AreEqual("myop", op.name);
                Assert.AreEqual("Identity", op.type);
                Assert.AreEqual(1, len(op.outputs));
                assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape);
            }
        }

        /// <summary>
        /// New ops whose requested names collide with existing ops get
        /// uniquified suffixes (_1, _2, ...).
        /// </summary>
        [TestMethod]
        public void TestUniqueName()
        {
            var graph = tf.Graph().as_default();
            //var (c_op,op_desc) = ops._create_c_op(g, ops._NodeDef("Const", "myop"), new Tensor[0], new Operation[0]);
            //var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]);
            //var op = g._create_op_from_tf_operation(c_op);
            //var op2 = g._create_op_from_tf_operation(c_op2);
            var op = constant_op.constant(0, name: "myop").op;
            var op2 = constant_op.constant(0, name: "myop_1").op;

            // Create ops with same names as op1 and op2. We expect the new names to be
            // uniquified.
            var op3 = constant_op.constant(0, name: "myop").op;
            var op4 = constant_op.constant(0, name: "myop_1").op;

            self.assertEqual(op.name, "myop");
            self.assertEqual(op2.name, "myop_1");
            self.assertEqual(op3.name, "myop_2");
            self.assertEqual(op4.name, "myop_1_1");
        }

        [Ignore("need tesnroflow expose UpdateEdge API")]
        [TestMethod]
        public void TestCond()
        {
            var g = tf.Graph().as_default();
            var x = constant_op.constant(10);

            // Creates a raw op inside the cond branch; the surrounding
            // control-flow context must be attached to it.
            var true_fn = new Func<Tensor>(() =>
            {
                var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]);
                var new_ops = g._add_new_tf_operations();
                self.assertEqual(len(new_ops), 1);
                return x;
            });

            control_flow_ops.cond(x < 10, true_fn, () => x);

            var op = g.get_operation_by_name("cond/myop");

            //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta.txt", as_text:true);
            //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);

            self.assertIsNotNone(op);
            self.assertEqual(op.name, "cond/myop");
            self.assertEqual(op.type, "Identity");
            //self.assertEqual(op.outputs, new object[0]);
            var op_input = op.inputs[0].op;
            self.assertEqual(op_input.type, "Switch");
            self.assertEqual(op_input.inputs[0].name, x.name);
            self.assertEqual(op.graph, g);
            self.assertIsNotNone(op._get_control_flow_context());
            var cond_text = op._get_control_flow_context() as ControlFlowContext;
            self.assertEqual(cond_text.Name, "cond/cond_text");
        }

        [Ignore("Todo: Port")]
        [TestMethod]
        public void TestWhileLoop()
        {
            var graph = tf.Graph().as_default();
            Operation x = null;
            x = constant_op.constant(42);
            var body = new Func<int, int>(i =>
            {
                ops._create_c_op(ops.get_default_graph(), ops._NodeDef("Identity", "myloop/myop"), new[] { x.output },
                    new Operation[0]);
                var new_ops = graph._add_new_tf_operations();
                self.assertEqual(len(new_ops), 1);
                return i;
            });
            // TODO: port control_flow_ops.while_loop
            //control_flow_ops.while_loop( i => i < 10, body, new int[]{0}, name = "myloop");
            var op = graph.get_operation_by_name("myloop/myop");
            self.assertIsNotNone(op);
            self.assertEqual(op.name, "myloop/myop");
            self.assertEqual(op.type, "Identity");
            self.assertEqual(op.outputs.Length, 0);
            var op_input = op.inputs[0].op;
            self.assertEqual(op_input.type, "Enter");
            self.assertItemsEqual(op_input.inputs.OfType<Operation>().ToArray(), new[] { x });
            self.assertEqual(op.graph, graph);
            self.assertIsNotNone(op._get_control_flow_context());
            self.assertEqual(((ControlFlowContext)op._get_control_flow_context()).Name, "myloop/while_context");
            /*
        @test_util.run_v1_only("b/120545219")
        def testWhileLoop(self):
          g = ops.Graph()
          with g.as_default():
            x = test_ops.int_output()

            def body(i):
              ops._create_c_op(ops.get_default_graph(),
                               ops._NodeDef("IntInput", "myloop/myop"), [x], [])
              new_ops = g._add_new_tf_operations()
              self.assertEqual(len(new_ops), 1)
              return i

            control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

          op = g.get_operation_by_name("myloop/myop")
          self.assertIsNotNone(op)
          self.assertEqual(op.name, "myloop/myop")
          self.assertEqual(op.type, "IntInput")
          self.assertEqual(op.outputs, [])
          op_input = op.inputs[0].op
          self.assertEqual(op_input.type, "Enter")
          self.assertEqual(list(op_input.inputs), [x])
          self.assertEqual(op.graph, g)
          # pylint: disable=protected-access
          self.assertIsNotNone(op._get_control_flow_context())
          self.assertEqual(op._get_control_flow_context().name,
                           "myloop/while_context")
          # pylint: enable=protected-access
            */
        }

        [Ignore("Todo: Port")]
        [TestMethod]
        public void TestWhileLoopWithInternalControlDep()
        {
            /*
        @test_util.run_v1_only("b/120545219")
        def testWhileLoopWithInternalControlDep(self):
          g = ops.Graph()
          with g.as_default():
            x = test_ops.int_output()

            def body(i):
              c = constant_op.constant(1.0, name="c")
              ops._create_c_op(ops.get_default_graph(),
                               ops._NodeDef("IntInput", "myloop/myop"), [x], [])
              with ops.control_dependencies([c]):
                new_ops = g._add_new_tf_operations()
                self.assertEqual(len(new_ops), 1)
              return i

            control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

          op = g.get_operation_by_name("myloop/myop")
          self.assertIsNotNone(op)
          c = g.get_operation_by_name("myloop/c")
          self.assertIsNotNone(c)
          # Internal control dep is preserved
          self.assertEqual(op.control_inputs, [c])
            */
        }

        [Ignore("Todo: Port")]
        [TestMethod]
        public void TestWhileLoopWithExternalControlDep()
        {
            /*
        @test_util.run_v1_only("b/120545219")
        def testWhileLoopWithExternalControlDep(self):
          g = ops.Graph()
          with g.as_default():
            x = test_ops.int_output()
            c = constant_op.constant(1.0)

            def body(i):
              ops._create_c_op(ops.get_default_graph(),
                               ops._NodeDef("IntInput", "myloop/myop"), [x], [])
              with ops.control_dependencies([c]):
                new_ops = g._add_new_tf_operations()
                self.assertEqual(len(new_ops), 1)
              return i

            control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

          op = g.get_operation_by_name("myloop/myop")
          self.assertIsNotNone(op)
          # External control dep is removed and replaced with internal control dep
          self.assertNotEqual(op.control_inputs[0], c.op)
          self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
            */
        }

    }
}

+ 0
- 196
test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs View File

@@ -1,196 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.UnitTest;

namespace TensorFlowNET.UnitTest.ops_test
{
    /// <summary>
    /// excerpt of tensorflow/python/framework/ops_test.py
    /// Default-graph lifecycle tests; most cases are still unported stubs
    /// carrying the original Python source in comments.
    /// </summary>
    [TestClass]
    public class GraphTest : GraphModeTestBase
    {
        // Reset the global default graph before each test so cases don't
        // leak ops into one another.
        [TestInitialize]
        public void SetUp()
        {
            ops.reset_default_graph();
        }

        // Reset again afterwards to leave a clean global state.
        [TestCleanup]
        public void TearDown()
        {
            ops.reset_default_graph();
        }

        // Helper: asserts that *expected* is the current default graph
        // (same instance, not merely equal).
        private void _AssertDefault(Graph expected)
        {
            Assert.AreSame(ops.get_default_graph(), expected);
        }


        [Ignore("Todo: Port")]
        [TestMethod]
        public void testResetDefaultGraphNesting()
        {
            /*
        def testResetDefaultGraphNesting(self):
          g0 = ops.Graph()
          with self.assertRaises(AssertionError):
            with g0.as_default():
              ops.reset_default_graph()
            */
        }

        [Ignore("Todo: Port")]
        [TestMethod]
        public void testGraphContextManagerCancelsEager()
        {
            /*
        def testGraphContextManagerCancelsEager(self):
          with context.eager_mode():
            with ops.Graph().as_default():
              self.assertFalse(context.executing_eagerly())
            */
        }


        [Ignore("Todo: Port")]
        [TestMethod]
        public void testGraphContextManager()
        {
            /*
        def testGraphContextManager(self):
          g0 = ops.Graph()
          with g0.as_default() as g1:
            self.assertIs(g0, g1)
            */
        }

        [Ignore("Todo: Port")]
        [TestMethod]
        public void testDefaultGraph()
        {
            /*
        def testDefaultGraph(self):
          orig = ops.get_default_graph()
          self._AssertDefault(orig)
          g0 = ops.Graph()
          self._AssertDefault(orig)
          context_manager_0 = g0.as_default()
          self._AssertDefault(orig)
          with context_manager_0 as g0:
            self._AssertDefault(g0)
            with ops.Graph().as_default() as g1:
              self._AssertDefault(g1)
            self._AssertDefault(g0)
          self._AssertDefault(orig)
            */
        }

        [Ignore("Todo: Port")]
        [TestMethod]
        public void testPreventFeeding()
        {
            /*
        def testPreventFeeding(self):
          g = ops.Graph()
          a = constant_op.constant(2.0)
          self.assertTrue(g.is_feedable(a))
          g.prevent_feeding(a)
          self.assertFalse(g.is_feedable(a))
            */
        }


        [Ignore("Todo: Port")]
        [TestMethod]
        public void testAsGraphElementConversions()
        {
            /*
        def testAsGraphElementConversions(self):

          class ConvertibleObj(object):

            def _as_graph_element(self):
              return "FloatOutput:0"

          class NonConvertibleObj(object):

            pass

          g = ops.Graph()
          a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
          with self.assertRaises(TypeError):
            g.as_graph_element(NonConvertibleObj())
            */
        }

        [Ignore("Todo: Port")]
        [TestMethod]
        public void testGarbageCollected()
        {
            /*
          # Regression test against creating custom __del__ functions in classes
          # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
          # cycles that require calling a __del__ method, because the __del__ method can
          # theoretically increase the object's refcount to "save" it from gc, and any
          # already-deleted objects in the cycle would have be to restored.)
        def testGarbageCollected(self):
          # Create a graph we can delete and a weak reference to monitor if it's gc'd
          g = ops.Graph()
          g_ref = weakref.ref(g)
          # Create some ops
          with g.as_default():
            a = constant_op.constant(2.0)
            b = constant_op.constant(3.0)
            c = math_ops.add(a, b)
          # Create a session we can delete
          with session.Session(graph=g) as sess:
            self.evaluate(c)
          # Delete all references and trigger gc
          del g
          del a
          del b
          del c
          del sess
          gc.collect()
          self.assertIsNone(g_ref())
            */
        }

        [Ignore("Todo: Port")]
        [TestMethod]
        public void testRunnableAfterInvalidShape()
        {
            /*
        def testRunnableAfterInvalidShape(self):
          with ops.Graph().as_default():
            with self.assertRaises(ValueError):
              math_ops.add([1, 2], [1, 2, 3])
            a = constant_op.constant(1)
            with session.Session() as sess:
              self.evaluate(a)
            */
        }

        [Ignore("Todo: Port")]
        [TestMethod]
        public void testRunnableAfterInvalidShapeWithKernelLabelMap()
        {
            /*
        def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
          g = ops.Graph()
          with g.as_default():
            with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
              with self.assertRaises(ValueError):
                test_ops.kernel_label_required(1)
            a = constant_op.constant(1)
            with session.Session() as sess:
              self.evaluate(a)
            */
        }


    }
}

+ 0
- 3014
test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py
File diff suppressed because it is too large
View File


+ 0
- 26
test/TensorFlowNET.UnitTest/python/train_saver.py View File

@@ -1,26 +0,0 @@

import tensorflow as tf

# Demo: build two variables, mutate them once, and checkpoint to disk
# with tf.train.Saver (TF1 graph/session API).

# Create some variables.
vec_a = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
vec_b = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

# Ops that bump v1 up and v2 down by one.
bump_up = vec_a.assign(vec_a + 1)
bump_down = vec_b.assign(vec_b - 1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    bump_up.op.run()
    bump_down.op.run()
    # Save the variables to disk.
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print("Model saved in path: %s" % save_path)

Loading…
Cancel
Save