diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln
index 3173c1a9..ffc39b5a 100644
--- a/TensorFlow.NET.sln
+++ b/TensorFlow.NET.sln
@@ -17,6 +17,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Text", "src\Tens
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Recommenders", "src\TensorFlowNET.Recommenders\Tensorflow.Recommenders.csproj", "{F17AAECB-960A-4E18-A270-BAD776F0E55B}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Native.UnitTest", "test\TensorFlowNET.Native.UnitTest\Tensorflow.Native.UnitTest.csproj", "{84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -107,8 +109,8 @@ Global
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.Build.0 = Release|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|Any CPU
- {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.Build.0 = Debug|Any CPU
+ {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|x64
+ {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.Build.0 = Debug|x64
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.ActiveCfg = Debug|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x86.Build.0 = Debug|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
@@ -155,8 +157,8 @@ Global
{49D71826-C03D-4FA7-9BAC-22C1327E65CF}.Release|x86.Build.0 = Release|Any CPU
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.ActiveCfg = Debug|Any CPU
- {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.Build.0 = Debug|Any CPU
+ {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.ActiveCfg = Debug|x64
+ {1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x64.Build.0 = Debug|x64
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.ActiveCfg = Debug|Any CPU
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug|x86.Build.0 = Debug|Any CPU
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
@@ -179,8 +181,8 @@ Global
{1AB8108D-4FFE-4A16-88E7-328EAF686370}.Release|x86.Build.0 = Release|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|Any CPU.Build.0 = Debug|Any CPU
- {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.ActiveCfg = Debug|Any CPU
- {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.Build.0 = Debug|Any CPU
+ {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.ActiveCfg = Debug|x64
+ {F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x64.Build.0 = Debug|x64
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.ActiveCfg = Debug|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug|x86.Build.0 = Debug|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
@@ -201,6 +203,30 @@ Global
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x64.Build.0 = Release|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.ActiveCfg = Release|Any CPU
{F17AAECB-960A-4E18-A270-BAD776F0E55B}.Release|x86.Build.0 = Release|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x64.ActiveCfg = Debug|x64
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x64.Build.0 = Debug|x64
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.ActiveCfg = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug|x86.Build.0 = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x64.Build.0 = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x86.ActiveCfg = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Debug-Minimal|x86.Build.0 = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|Any CPU.Build.0 = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x64.ActiveCfg = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x64.Build.0 = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x86.ActiveCfg = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Publish|x86.Build.0 = Debug|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|Any CPU.Build.0 = Release|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.ActiveCfg = Release|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x64.Build.0 = Release|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.ActiveCfg = Release|Any CPU
+ {84CA35F8-99FC-408E-8DF3-5AA175E5EFD3}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/src/TensorFlowNET.Console/Tensorflow.Console.csproj b/src/TensorFlowNET.Console/Tensorflow.Console.csproj
index d2b64e80..11dda95f 100644
--- a/src/TensorFlowNET.Console/Tensorflow.Console.csproj
+++ b/src/TensorFlowNET.Console/Tensorflow.Console.csproj
@@ -5,6 +5,7 @@
netcoreapp3.1
Tensorflow
Tensorflow
+ AnyCPU;x64
diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs
index 7c0e7585..eec7f7f8 100644
--- a/src/TensorFlowNET.Core/APIs/tf.graph.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs
@@ -28,17 +28,10 @@ namespace Tensorflow
=> ops.reset_default_graph();
public Graph get_default_graph()
- {
- return ops.get_default_graph();
- }
+ => ops.get_default_graph();
- ///
- /// Equivalent to but does not create a new graph if it there is none.
- ///
public Graph peak_default_graph()
- {
- return ops.default_graph_stack.peak_controller();
- }
+ => ops.peak_default_graph();
///
/// Creates a new graph.
diff --git a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
index 3626e9df..1a5c00d2 100644
--- a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
+++ b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
@@ -37,19 +37,7 @@ namespace Tensorflow.Contexts
if (shouldRunInEager)
return eagerAction();
else
- {
- if (executing_eagerly())
- {
- graph_mode();
- var result = graphAction();
- restore_mode();
- return result;
- }
- else
- {
- return graphAction();
- }
- }
+ return graphAction();
}
// [DebuggerStepThrough]
diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs
index 17e85306..43564fdb 100644
--- a/src/TensorFlowNET.Core/Contexts/Context.cs
+++ b/src/TensorFlowNET.Core/Contexts/Context.cs
@@ -80,9 +80,12 @@ namespace Tensorflow.Contexts
/// Checks whether the current thread has eager execution enabled.
///
///
- [DebuggerStepThrough]
+ // [DebuggerStepThrough]
public bool executing_eagerly()
{
+ if(context_switches.Count() == 0)
+ tf.enable_eager_execution();
+
return context_switches.Current().EagerMode;
}
@@ -103,11 +106,16 @@ namespace Tensorflow.Contexts
public void restore_mode()
{
context_switches.Pop();
+ tf.get_default_graph();
}
public void reset_context()
{
- c_api.TFE_ContextClearCaches(_handle);
+ ops.reset_uid();
+ ops.reset_default_graph();
+ context_switches.Clear();
+ if (_handle != null)
+ c_api.TFE_ContextClearCaches(_handle);
}
public void Dispose()
diff --git a/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs b/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs
index 84bc3889..27704b3e 100644
--- a/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs
+++ b/src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs
@@ -40,11 +40,21 @@ namespace Tensorflow.Contexts
});
}
+ public void Clear()
+ {
+ stack.Clear();
+ }
+
public void Pop()
{
stack.Pop();
}
+ public int Count()
+ {
+ return stack.Count;
+ }
+
public ContextSwitch Current()
{
return stack.Peek();
diff --git a/src/TensorFlowNET.Core/Data/MapDataset.cs b/src/TensorFlowNET.Core/Data/MapDataset.cs
index 09513f94..50176cdc 100644
--- a/src/TensorFlowNET.Core/Data/MapDataset.cs
+++ b/src/TensorFlowNET.Core/Data/MapDataset.cs
@@ -15,11 +15,13 @@ namespace Tensorflow
bool preserve_cardinality = false,
bool use_legacy_function = false) : base(input_dataset)
{
- using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");
+ using var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}");
+ func.Enter();
var input = tf.placeholder(input_dataset.element_spec[0].dtype);
var output = map_func(input);
func.ToGraph(input, output);
-
+ func.Exit();
+
structure = func.OutputStructure;
variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
index 45fa3420..b1d932b6 100644
--- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
+++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
@@ -34,7 +34,6 @@ namespace Tensorflow.Functions
public ConcreteFunction(string name)
{
func_graph = new FuncGraph(name);
- func_graph.as_default();
}
public ConcreteFunction(FuncGraph graph, Dictionary attrs = null)
@@ -46,7 +45,7 @@ namespace Tensorflow.Functions
public ConcreteFunction(Func func, TF_DataType dtype)
{
- string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
+ string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
// IntPtr func_handle;
using var graph = new FuncGraph(func_name);
@@ -59,11 +58,12 @@ namespace Tensorflow.Functions
new[] { input },
new[] { output },
null);
+ graph.Exit();
}
public ConcreteFunction(Func func, TF_DataType dtype)
{
- string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
+ string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
// IntPtr func_handle;
using var graph = new FuncGraph(func_name);
@@ -79,12 +79,13 @@ namespace Tensorflow.Functions
new[] { input },
new[] { output.variant_tensor },
null);
+ graph.Exit();
}
public ConcreteFunction(Func func,
TF_DataType[] dtypes, TensorShape[] shapes)
{
- string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
+ string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
// IntPtr func_handle;
using var graph = new FuncGraph(func_name);
@@ -103,6 +104,7 @@ namespace Tensorflow.Functions
new[] { input1, input2, input3 },
new[] { outputs.Item1, outputs.Item2 },
null);
+ graph.Exit();
}
public void ToGraph(Tensors inputs, Tensors outputs)
@@ -112,10 +114,19 @@ namespace Tensorflow.Functions
inputs,
outputs,
null);
-
OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray();
}
+ public void Enter()
+ {
+ func_graph.as_default();
+ }
+
+ public void Exit()
+ {
+ func_graph.Exit();
+ }
+
public Tensors Invoke(Tensors inputs)
{
var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly());
diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
index e0896253..f615f6a4 100644
--- a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
+++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
@@ -26,7 +26,6 @@ namespace Tensorflow.Functions
var output_names = new string[0];
_func_graph = new FuncGraph(graph, name, attrs);
- _func_graph.as_default();
_func_graph.ToGraph(operations, inputs, outputs, output_names);
}
diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
index 5dd1a8ae..13c57e86 100644
--- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
+++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
@@ -84,7 +84,7 @@ namespace Tensorflow.Functions
}
var gradients_wrt_outputs = new List();
- var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}");
+ var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}");
backwards_graph.as_default();
foreach (var output in trainable_outputs)
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));
@@ -101,6 +101,7 @@ namespace Tensorflow.Functions
if (!_func_graph.Outputs.Contains(capture))
_func_graph.Outputs.Add(capture);
}
+ backwards_graph.Exit();
var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}";
var backward_function_attr = new Dictionary();
diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs
index 901cbd6f..78b19a6f 100644
--- a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs
+++ b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs
@@ -8,7 +8,7 @@ namespace Tensorflow.Graphs
{
public Func to_graph(Func func)
{
- string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
+ string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
// IntPtr func_handle;
using (var graph = new FuncGraph(func_name))
@@ -22,6 +22,7 @@ namespace Tensorflow.Graphs
new[] { input },
new[] { output },
null);
+ graph.Exit();
}
@@ -39,7 +40,7 @@ namespace Tensorflow.Graphs
public Func to_graph(Func func)
{
- string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
+ string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
// IntPtr func_handle;
using (var graph = new FuncGraph(func_name))
@@ -54,6 +55,7 @@ namespace Tensorflow.Graphs
new[] { input1, input2 },
new[] { output },
null);
+ graph.Exit();
}
return (Tensor a, Tensor b) =>
diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs
index 3d04b8a3..010eb345 100644
--- a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs
+++ b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs
@@ -18,7 +18,7 @@ namespace Tensorflow.Graphs
public override void OnEntry(MethodExecutionArgs args)
{
- func_name = $"autograph_{args.Instance.GetHashCode()}.{args.Method.Name}";
+ func_name = $"{args.Method.Name}_{Guid.NewGuid()}";
if (functions.ContainsKey(func_name))
{
@@ -34,6 +34,7 @@ namespace Tensorflow.Graphs
// make function as an Operation by autograph
// need to restore mode when exits
function = new ConcreteFunction(func_name);
+ function.Enter();
// convert to Tensors
if (args.Arguments[0] is Tensors inputs)
@@ -68,6 +69,8 @@ namespace Tensorflow.Graphs
}
else
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor);
+
+ function.Exit();
// cache function.
function.ReturnType = args.ReturnValue.GetType();
diff --git a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs
index 5407dcfe..0eee3cdb 100644
--- a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs
+++ b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs
@@ -25,63 +25,43 @@ namespace Tensorflow
///
public class DefaultGraphStack
{
- private readonly List _stack = new List();
+ private readonly Stack _stack = new Stack();
+ Graph _global_default_graph;
- public void set_controller(Graph @default)
+ public Graph get_default()
{
- if (!_stack.Exists(x => x.Graph == @default))
- _stack.Add(new StackModel { Graph = @default, IsDefault = true });
+ if (_stack.Count > 0)
+ return _stack.Peek();
+ else if (_global_default_graph != null)
+ return _global_default_graph;
+ else
+ _global_default_graph = new Graph();
- foreach (var s in _stack)
- s.IsDefault = s.Graph == @default;
+ return _global_default_graph;
}
- public Graph get_controller()
+ public Graph get_controller(Graph g)
{
- if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0)
- _stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true });
- for (var i = _stack.Count - 1; i >= 0; i--)
- {
- var x = _stack[i];
- if (x.IsDefault)
- return x.Graph;
- }
-
- throw new TensorflowException("Unable to find a default graph");
+ _stack.Push(g);
+ return g;
}
public Graph peak_controller()
{
- if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0)
+ if (_stack.Count == 0)
return null;
- for (var i = _stack.Count - 1; i >= 0; i--)
- {
- var x = _stack[i];
- if (x.IsDefault)
- return x.Graph;
- }
-
- return null;
+ return _stack.Peek();
}
- public bool remove(Graph g)
+ public void pop()
{
- if (_stack.Count == 0)
- return false;
-
- var sm = _stack.Find(model => model.Graph == g);
- return sm != null && _stack.Remove(sm);
+ _stack.Pop();
}
public void reset()
{
_stack.Clear();
- }
-
- private class StackModel
- {
- public Graph Graph { get; set; }
- public bool IsDefault { get; set; }
+ _global_default_graph = null;
}
}
}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs
index 37f03267..797f81ad 100644
--- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs
+++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs
@@ -94,8 +94,6 @@ namespace Tensorflow.Graphs
// mark_as_return
Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray();
- tf.Context.restore_mode();
-
return func_handle;
}
@@ -247,9 +245,10 @@ namespace Tensorflow.Graphs
return this;
}
- protected override void DisposeManagedResources()
+ public override void Exit()
{
- base.DisposeManagedResources();
+ tf.Context.restore_mode();
+ ops.pop_graph();
}
}
}
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index 81cb08c4..e3f14336 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -146,6 +146,7 @@ namespace Tensorflow
///
/// Returns a context manager that makes this `Graph` the default graph.
+ /// Must call Exit() to pop graph
///
///
public virtual Graph as_default()
@@ -487,7 +488,7 @@ namespace Tensorflow
protected override void DisposeManagedResources()
{
- ops.default_graph_stack.remove(this);
+
}
protected override void DisposeUnmanagedResources(IntPtr handle)
@@ -529,6 +530,11 @@ namespace Tensorflow
return new TensorShape(dims.Select(x => (int)x).ToArray());
}
+ public virtual void Exit()
+ {
+ ops.pop_graph();
+ }
+
string debugString = string.Empty;
public override string ToString()
{
diff --git a/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs b/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs
index 383fe8af..7c186f94 100644
--- a/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs
+++ b/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs
@@ -95,7 +95,7 @@ namespace Tensorflow.Graphs
_copy_non_source(op, graph, op_map, base_graph);
}
- tf.Context.restore_mode();
+ graph.Exit();
return op_map;
}
diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
index c02de5e8..1b2dcd24 100644
--- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
+++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
@@ -30,7 +30,7 @@ namespace Tensorflow
public Operation _apply_op_helper(string op_type_name, string name = null, Dictionary keywords = null)
{
- var g = ops.get_default_graph();
+ var g = ops._get_graph_from_inputs(keywords == null ? new object[0] : keywords.Values.ToArray());
var op_def = g.GetOpDef(op_type_name);
// Default name if not specified.
@@ -59,7 +59,8 @@ namespace Tensorflow
var input_types = new List();
object values = null;
- return tf_with(ops.name_scope(name), scope =>
+ g.as_default();
+ var ret_op = tf_with(ops.name_scope(name), scope =>
{
var inferred_from = new Dictionary();
var base_types = new List();
@@ -249,6 +250,8 @@ namespace Tensorflow
return op;
});
+ g.Exit();
+ return ret_op;
}
private void _MaybeColocateWith(ITensorOrOperation[] inputs)
diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs
index 725a50b6..18dc4426 100644
--- a/src/TensorFlowNET.Core/ops.cs
+++ b/src/TensorFlowNET.Core/ops.cs
@@ -78,6 +78,21 @@ namespace Tensorflow
return get_default_graph().get_collection_ref(key);
}
+ public static Graph _get_graph_from_inputs(params object[] op_input_list)
+ {
+ var current_default_graph = get_default_graph();
+ if (current_default_graph.building_function)
+ return current_default_graph;
+
+ Graph graph = null;
+ foreach (var op_input in op_input_list)
+ {
+ if (op_input is Tensor op_input_tensor)
+ graph = graph ?? op_input_tensor.graph;
+ }
+ return graph ?? current_default_graph;
+ }
+
public static Graph _get_graph_from_inputs(Tensors op_input_list)
=> _get_graph_from_inputs(op_input_list: op_input_list, graph: null);
@@ -337,6 +352,11 @@ namespace Tensorflow
return Interlocked.Increment(ref uid_number);
}
+ public static void reset_uid()
+ {
+ uid_number = -1;
+ }
+
public static void colocate_with(bool ignore_existing = false)
{
_colocate_with_for_gradient(null, null, ignore_existing);
diff --git a/src/TensorFlowNET.Core/ops.threading.cs b/src/TensorFlowNET.Core/ops.threading.cs
index e436cae0..f52dbcae 100644
--- a/src/TensorFlowNET.Core/ops.threading.cs
+++ b/src/TensorFlowNET.Core/ops.threading.cs
@@ -118,16 +118,10 @@ namespace Tensorflow
///
///
public static Graph get_default_graph()
- {
- //return _default_graph_stack.get_default()
- return default_graph_stack.get_controller();
- }
+ => default_graph_stack.get_default();
- public static Graph set_default_graph(Graph graph)
- {
- default_graph_stack.set_controller(graph);
- return default_graph_stack.get_controller();
- }
+ public static Graph set_default_graph(Graph g)
+ => default_graph_stack.get_controller(g);
///
/// Clears the default graph stack and resets the global default graph.
@@ -147,5 +141,11 @@ namespace Tensorflow
// "exit the nesting and create a new graph.");
default_graph_stack.reset();
}
+
+ public static Graph peak_default_graph()
+ => default_graph_stack.peak_controller();
+
+ public static void pop_graph()
+ => default_graph_stack.pop();
}
}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index e24294fc..4c25f70a 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -115,10 +115,10 @@ namespace Tensorflow.Keras
public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>();
public void clear_session()
{
- ops.reset_default_graph();
+ tf.Context.reset_context();
reset_uids();
ops.set_default_session(tf.Session(ops.get_default_graph()));
- var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase");
+ // var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase");
_GRAPH_LEARNING_PHASES = new Dictionary();
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0;
}
@@ -185,7 +185,7 @@ namespace Tensorflow.Keras
return tensor_util.constant_value(outputs);
var source_graph = outputs.graph;
- using var exec_graph = _scratch_graph();
+ var exec_graph = _scratch_graph();
var global_graph = get_graph();
if (source_graph == global_graph && exec_graph != global_graph)
{
diff --git a/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs b/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs
index 25a7cd7a..699bec4c 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs
@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine
_set_mask_metadata(inputs, outputs, null);
});
- tf.Context.restore_mode();
+ graph.Exit();
return outputs;
}
diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
index 54af955a..4e1e0e36 100644
--- a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
+++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
@@ -81,8 +81,9 @@ namespace Tensorflow.Keras.Layers
sparse: args.Sparse,
ragged: args.Ragged);
+ graph.Exit();
+
isPlaceholder = true;
- tf.Context.restore_mode();
}
// Create an input node to add to self.outbound_node
diff --git a/src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj b/src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj
index ba47de36..8169ab66 100644
--- a/src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj
+++ b/src/TensorFlowNET.Recommenders/Tensorflow.Recommenders.csproj
@@ -5,6 +5,7 @@
0.0.1
TensorFlow Recommenders is a library for building recommender system models using TensorFlow.
LICENSE
+ AnyCPU;x64
diff --git a/src/TensorFlowNET.Text/Tensorflow.Text.csproj b/src/TensorFlowNET.Text/Tensorflow.Text.csproj
index e687524c..1ce9c875 100644
--- a/src/TensorFlowNET.Text/Tensorflow.Text.csproj
+++ b/src/TensorFlowNET.Text/Tensorflow.Text.csproj
@@ -7,12 +7,17 @@
true
0.0.1
LICENSE
+ AnyCPU;x64
DEBUG;TRACE
+
+ DEBUG;TRACE
+
+
True
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiAttributesTestcs.cs b/test/TensorFlowNET.Native.UnitTest/CApiAttributesTestcs.cs
similarity index 98%
rename from test/TensorFlowNET.UnitTest/NativeAPI/CApiAttributesTestcs.cs
rename to test/TensorFlowNET.Native.UnitTest/CApiAttributesTestcs.cs
index 0a18f63f..a22d02e0 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiAttributesTestcs.cs
+++ b/test/TensorFlowNET.Native.UnitTest/CApiAttributesTestcs.cs
@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
-using Tensorflow;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest
{
///
/// tensorflow\c\c_api_test.cc
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiColocationTest.cs b/test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs
similarity index 98%
rename from test/TensorFlowNET.UnitTest/NativeAPI/CApiColocationTest.cs
rename to test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs
index 9d892199..09b88e13 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiColocationTest.cs
+++ b/test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs
@@ -1,9 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Runtime.InteropServices;
-using Tensorflow;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest
{
///
/// tensorflow\c\c_api_test.cc
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiFunctionTest.cs b/test/TensorFlowNET.Native.UnitTest/CApiFunctionTest.cs
similarity index 99%
rename from test/TensorFlowNET.UnitTest/NativeAPI/CApiFunctionTest.cs
rename to test/TensorFlowNET.Native.UnitTest/CApiFunctionTest.cs
index 42d2a22c..570c6ae1 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiFunctionTest.cs
+++ b/test/TensorFlowNET.Native.UnitTest/CApiFunctionTest.cs
@@ -2,10 +2,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
-using Tensorflow;
-using static TensorFlowNET.UnitTest.c_test_util;
+using static Tensorflow.Native.UnitTest.c_test_util;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest
{
///
/// tensorflow\c\c_api_function_test.cc
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiGradientsTest.cs b/test/TensorFlowNET.Native.UnitTest/CApiGradientsTest.cs
similarity index 99%
rename from test/TensorFlowNET.UnitTest/NativeAPI/CApiGradientsTest.cs
rename to test/TensorFlowNET.Native.UnitTest/CApiGradientsTest.cs
index 77698f99..6db90aef 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiGradientsTest.cs
+++ b/test/TensorFlowNET.Native.UnitTest/CApiGradientsTest.cs
@@ -1,11 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using System;
-using Tensorflow;
using Tensorflow.Util;
-using Buffer = Tensorflow.Buffer;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest
{
///
/// tensorflow\c\c_api_test.cc
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs b/test/TensorFlowNET.Native.UnitTest/CApiTest.cs
similarity index 98%
rename from test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs
rename to test/TensorFlowNET.Native.UnitTest/CApiTest.cs
index c9366b5f..e8a9486f 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs
+++ b/test/TensorFlowNET.Native.UnitTest/CApiTest.cs
@@ -1,13 +1,11 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
-using Tensorflow;
using Tensorflow.Device;
using Tensorflow.Eager;
-using Tensorflow.UnitTest;
-namespace TensorFlowNET.UnitTest
+namespace Tensorflow.Native.UnitTest
{
- public class CApiTest : GraphModeTestBase
+ public class CApiTest
{
protected static readonly TF_Code TF_OK = TF_Code.TF_OK;
protected static readonly TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT;
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/CSession.cs b/test/TensorFlowNET.Native.UnitTest/CSession.cs
similarity index 98%
rename from test/TensorFlowNET.UnitTest/NativeAPI/CSession.cs
rename to test/TensorFlowNET.Native.UnitTest/CSession.cs
index 651bc853..c973e1b3 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/CSession.cs
+++ b/test/TensorFlowNET.Native.UnitTest/CSession.cs
@@ -1,10 +1,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
-using Tensorflow;
using Tensorflow.Util;
-namespace TensorFlowNET.UnitTest
+namespace Tensorflow.Native.UnitTest
{
///
/// tensorflow\c\c_test_util.cc
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Context.cs
similarity index 95%
rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs
rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Context.cs
index ea737e3a..7628bbc2 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Context.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Context.cs
@@ -1,9 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
-using Tensorflow;
using Tensorflow.Device;
using Tensorflow.Eager;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Execute_MatMul_CPU.cs
similarity index 97%
rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs
rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Execute_MatMul_CPU.cs
index 441f9d27..e8c6844a 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Execute_MatMul_CPU.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Execute_MatMul_CPU.cs
@@ -1,10 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
-using Tensorflow;
using Tensorflow.Eager;
using static Tensorflow.Binding;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs
similarity index 97%
rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs
rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs
index 1a46a0fa..ce5a287f 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs
@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
-using Tensorflow;
using Tensorflow.Eager;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs
similarity index 97%
rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs
rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs
index 372534ee..ad878115 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.OpInferMixedTypeInputListAttrs.cs
@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
-using Tensorflow;
using Tensorflow.Eager;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandle.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandle.cs
similarity index 93%
rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandle.cs
rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandle.cs
index 279b89cf..8f0c3b40 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandle.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandle.cs
@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
-using Tensorflow;
using static Tensorflow.Binding;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs
similarity index 98%
rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs
rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs
index c16dca9b..9fc8f95e 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.TensorHandleDevices.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs
@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
-using Tensorflow;
using Tensorflow.Eager;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Variables.cs
similarity index 97%
rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs
rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Variables.cs
index cb32dee0..e6a091dc 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.Variables.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.Variables.cs
@@ -1,9 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
-using Tensorflow;
using Tensorflow.Eager;
using static Tensorflow.Binding;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest.Eager
{
public partial class CApiEagerTest
{
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.cs
similarity index 99%
rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs
rename to test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.cs
index 769a8eb5..a9dec9b1 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/CApi.Eager.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/CApi.Eager.cs
@@ -1,10 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
-using Tensorflow;
using Tensorflow.Eager;
using static Tensorflow.Binding;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest.Eager
{
///
/// tensorflow\c\eager\c_api_test.cc
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs b/test/TensorFlowNET.Native.UnitTest/Eager/GradientEagerTest.cs
similarity index 96%
rename from test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs
rename to test/TensorFlowNET.Native.UnitTest/Eager/GradientEagerTest.cs
index 66c35c1b..b4286bb7 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Eager/GradientEagerTest.cs
@@ -1,13 +1,12 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq;
-using Tensorflow;
using static Tensorflow.Binding;
-namespace TensorFlowNET.UnitTest.Gradient
+namespace Tensorflow.Native.UnitTest.Eager
{
[TestClass]
- public class GradientEagerTest : PythonTest
+ public class GradientEagerTest
{
[TestMethod]
public void ConstantSquare()
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs b/test/TensorFlowNET.Native.UnitTest/GraphBuildTest.cs
similarity index 94%
rename from test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs
rename to test/TensorFlowNET.Native.UnitTest/GraphBuildTest.cs
index 3d305a89..751dc355 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/GraphBuildTest.cs
+++ b/test/TensorFlowNET.Native.UnitTest/GraphBuildTest.cs
@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
-using Tensorflow;
using static Tensorflow.Binding;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest
{
[TestClass]
public class GraphBuildTest : CApiTest
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/GraphTest.cs b/test/TensorFlowNET.Native.UnitTest/GraphTest.cs
similarity index 99%
rename from test/TensorFlowNET.UnitTest/NativeAPI/GraphTest.cs
rename to test/TensorFlowNET.Native.UnitTest/GraphTest.cs
index 4b75becc..e40fd5c8 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/GraphTest.cs
+++ b/test/TensorFlowNET.Native.UnitTest/GraphTest.cs
@@ -1,10 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
-using Tensorflow;
using static Tensorflow.Binding;
-using Buffer = Tensorflow.Buffer;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace Tensorflow.Native.UnitTest
{
[TestClass]
public class GraphTest : CApiTest
diff --git a/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs b/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs
new file mode 100644
index 00000000..9ae2e379
--- /dev/null
+++ b/test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs
@@ -0,0 +1,74 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+
+namespace Tensorflow.Native.UnitTest.Sessions
+{
+ [TestClass, Ignore]
+ public class SessionTest : CApiTest
+ {
+ ///
+ /// tensorflow\c\c_api_test.cc
+ /// `TEST(CAPI, Session)`
+ ///
+ [TestMethod]
+ public void Session()
+ {
+ using var s = new Status();
+ using var graph = new Graph();
+
+ // Make a placeholder operation.
+ var feed = c_test_util.Placeholder(graph, s);
+
+ // Make a constant operation with the scalar "2".
+ var two = c_test_util.ScalarConst(2, graph, s);
+
+ // Add operation.
+ var add = c_test_util.Add(feed, two, graph, s);
+
+ var csession = new CSession(graph, s);
+ ASSERT_EQ(TF_Code.TF_OK, s.Code);
+
+ // Run the graph.
+ var inputs = new Dictionary();
+ inputs.Add(feed, new Tensor(3));
+ csession.SetInputs(inputs);
+
+ var outputs = new TF_Output[] { new TF_Output(add, 0) };
+ csession.SetOutputs(outputs);
+
+ csession.Run(s);
+ Tensor outTensor = csession.output_tensor(0);
+ EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
+ EXPECT_EQ(0, outTensor.NDims);
+ ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
+ var output_contents = outTensor.ToArray();
+ EXPECT_EQ(3 + 2, output_contents[0]);
+
+ // Add another operation to the graph.
+ var neg = c_test_util.Neg(add, graph, s);
+ ASSERT_EQ(TF_Code.TF_OK, s.Code);
+
+ // Run up to the new operation.
+ inputs = new Dictionary();
+ inputs.Add(feed, new Tensor(7));
+ csession.SetInputs(inputs);
+ outputs = new TF_Output[] { new TF_Output(neg, 0) };
+ csession.SetOutputs(outputs);
+ csession.Run(s);
+ ASSERT_EQ(TF_Code.TF_OK, s.Code);
+
+ outTensor = csession.output_tensor(0);
+ ASSERT_TRUE(outTensor != IntPtr.Zero);
+ EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
+ EXPECT_EQ(0, outTensor.NDims); // scalar
+ ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
+ output_contents = outTensor.ToArray();
+ EXPECT_EQ(-(7 + 2), output_contents[0]);
+
+ // Clean up
+ csession.CloseAndDelete(s);
+ ASSERT_EQ(TF_Code.TF_OK, s.Code);
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj b/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj
new file mode 100644
index 00000000..05b28723
--- /dev/null
+++ b/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj
@@ -0,0 +1,36 @@
+
+
+
+ netcoreapp3.1
+
+ false
+
+ AnyCPU;x64
+
+
+
+ true
+ DEBUG;TRACE
+ x64
+
+
+
+ true
+ DEBUG;TRACE
+ x64
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs
new file mode 100644
index 00000000..b7a208e4
--- /dev/null
+++ b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs
@@ -0,0 +1,204 @@
+using FluentAssertions;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using System;
+using System.Linq;
+using System.Runtime.InteropServices;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Native.UnitTest.Tensors
+{
+ [TestClass]
+ public class TensorTest : CApiTest
+ {
+ [TestMethod]
+ public unsafe void TensorFromFixed()
+ {
+ var array = new float[1000];
+ var span = new Span(array, 100, 500);
+ fixed (float* ptr = &MemoryMarshal.GetReference(span))
+ {
+ using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, 4 * span.Length))
+ {
+ Assert.IsFalse(t.IsDisposed);
+ Assert.AreEqual(2000, (int)t.bytesize);
+ }
+ }
+
+ fixed (float* ptr = &array[0])
+ {
+ using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length))
+ {
+ Assert.IsFalse(t.IsDisposed);
+ Assert.AreEqual(4000, (int)t.bytesize);
+ }
+ }
+ }
+
+ [TestMethod]
+ public void TensorFromArray()
+ {
+ var array = new float[1000];
+ using (var t = new Tensor(array, new long[] { array.Length }, tf.float32))
+ {
+ Assert.IsFalse(t.IsDisposed);
+ Assert.AreEqual(1000 * sizeof(float), (int)t.bytesize);
+ }
+
+ using (var t = new Tensor(new float[] { 1 }, new long[] { 1 }, tf.float32))
+ {
+ Assert.IsFalse(t.IsDisposed);
+ Assert.AreEqual(1 * sizeof(float), (int)t.bytesize);
+ }
+
+ using (var t = new Tensor(new float[] { 1 }, null, tf.float32))
+ {
+ Assert.IsFalse(t.IsDisposed);
+ Assert.AreEqual(1 * sizeof(float), (int)t.bytesize);
+ t.shape.Should().BeEmpty();
+ }
+ }
+
+ [TestMethod]
+ public void AllocateTensor()
+ {
+ ulong num_bytes = 6 * sizeof(float);
+ long[] dims = { 2, 3 };
+ Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes);
+ EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
+ EXPECT_EQ(2, t.NDims);
+ EXPECT_EQ((int)dims[0], t.shape[0]);
+ EXPECT_EQ(num_bytes, t.bytesize);
+ t.Dispose();
+ }
+
+
+ ///
+ /// Port from c_api_test.cc
+ /// `TEST(CAPI, MaybeMove)`
+ ///
+ [TestMethod, Ignore]
+ public void MaybeMove()
+ {
+ NDArray nd = np.array(2, 3);
+ Tensor t = new Tensor(nd);
+ Tensor o = t.MaybeMove();
+ ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own.
+ t.Dispose();
+ }
+
+ ///
+ /// Port from c_api_test.cc
+ /// `TEST(CAPI, Tensor)`
+ ///
+ [TestMethod]
+ public void Tensor()
+ {
+ var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);
+
+ var tensor = new Tensor(nd);
+ var array = tensor.ToArray();
+
+ EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
+ EXPECT_EQ(tensor.rank, nd.ndim);
+ EXPECT_EQ((int)tensor.shape[0], nd.shape[0]);
+ EXPECT_EQ((int)tensor.shape[1], nd.shape[1]);
+ EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float));
+ Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), new float[] { 1, 2, 3, 4, 5, 6 }));
+ }
+
+ ///
+ /// Port from tensorflow\c\c_api_test.cc
+ /// `TEST(CAPI, SetShape)`
+ ///
+ [TestMethod]
+ public void SetShape()
+ {
+ var s = new Status();
+ var graph = new Graph().as_default();
+
+ var feed = c_test_util.Placeholder(graph, s);
+ var feed_out_0 = new TF_Output(feed, 0);
+
+ // Fetch the shape, it should be completely unknown.
+ int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
+
+ Assert.IsTrue(s.Code == TF_Code.TF_OK);
+ EXPECT_EQ(-1, num_dims);
+
+ // Set the shape to be unknown, expect no change.
+ c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_OK);
+ num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
+ EXPECT_EQ(-1, num_dims);
+
+ // Set the shape to be 2 x Unknown
+ long[] dims = { 2, -1 };
+ c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_OK);
+ num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
+ EXPECT_EQ(2, num_dims);
+
+ // Get the dimension vector appropriately.
+ var returned_dims = new long[dims.Length];
+ c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_OK);
+ Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));
+
+ // Set to a new valid shape: [2, 3]
+ dims[1] = 3;
+ c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_OK);
+
+ // Fetch and see that the new value is returned.
+ c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_OK);
+ Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));
+
+ // Try to set 'unknown' with unknown rank on the shape and see that
+ // it doesn't change.
+ c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_OK);
+ c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_OK);
+ EXPECT_EQ(2, num_dims);
+ EXPECT_EQ(2, (int)returned_dims[0]);
+ EXPECT_EQ(3, (int)returned_dims[1]);
+
+ // Try to set 'unknown' with same rank on the shape and see that
+ // it doesn't change.
+ dims[0] = -1;
+ dims[1] = -1;
+ c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_OK);
+ c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_OK);
+ EXPECT_EQ(2, num_dims);
+ EXPECT_EQ(2, (int)returned_dims[0]);
+ EXPECT_EQ(3, (int)returned_dims[1]);
+
+ // Try to fetch a shape with the wrong num_dims
+ c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);
+
+ // Try to set an invalid shape (cannot change 2x3 to a 2x5).
+ dims[1] = 5;
+ c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);
+
+ // Test for a scalar.
+ var three = c_test_util.ScalarConst(3, graph, s);
+ Assert.IsTrue(s.Code == TF_Code.TF_OK);
+ var three_out_0 = new TF_Output(three, 0);
+
+ num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_OK);
+ EXPECT_EQ(0, num_dims);
+ c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s.Handle);
+ Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);
+
+ graph.Exit();
+ s.Dispose();
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs b/test/TensorFlowNET.Native.UnitTest/c_test_util.cs
similarity index 98%
rename from test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs
rename to test/TensorFlowNET.Native.UnitTest/c_test_util.cs
index fea762aa..1a43b3e1 100644
--- a/test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs
+++ b/test/TensorFlowNET.Native.UnitTest/c_test_util.cs
@@ -1,10 +1,8 @@
using System;
using System.Diagnostics.CodeAnalysis;
-using Tensorflow;
using Tensorflow.Util;
-using Buffer = Tensorflow.Buffer;
-namespace TensorFlowNET.UnitTest
+namespace Tensorflow.Native.UnitTest
{
///
/// Port from `tensorflow\c\c_test_util.cc`
diff --git a/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs b/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs
index 0bc38dd1..b95a7c6a 100644
--- a/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs
+++ b/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs
@@ -8,78 +8,11 @@ using Tensorflow;
using Tensorflow.Util;
using static Tensorflow.Binding;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace TensorFlowNET.UnitTest
{
- [TestClass]
- public class SessionTest : CApiTest
+ [TestClass, Ignore]
+ public class SessionTest
{
- ///
- /// tensorflow\c\c_api_test.cc
- /// `TEST(CAPI, Session)`
- ///
- [TestMethod, Ignore]
- public void Session()
- {
- lock (Locks.ProcessWide)
- {
- var s = new Status();
- var graph = new Graph().as_default();
-
- // Make a placeholder operation.
- var feed = c_test_util.Placeholder(graph, s);
-
- // Make a constant operation with the scalar "2".
- var two = c_test_util.ScalarConst(2, graph, s);
-
- // Add operation.
- var add = c_test_util.Add(feed, two, graph, s);
-
- var csession = new CSession(graph, s);
- ASSERT_EQ(TF_Code.TF_OK, s.Code);
-
- // Run the graph.
- var inputs = new Dictionary();
- inputs.Add(feed, new Tensor(3));
- csession.SetInputs(inputs);
-
- var outputs = new TF_Output[] { new TF_Output(add, 0) };
- csession.SetOutputs(outputs);
-
- csession.Run(s);
- Tensor outTensor = csession.output_tensor(0);
- EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
- EXPECT_EQ(0, outTensor.NDims);
- ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
- var output_contents = outTensor.ToArray();
- EXPECT_EQ(3 + 2, output_contents[0]);
-
- // Add another operation to the graph.
- var neg = c_test_util.Neg(add, graph, s);
- ASSERT_EQ(TF_Code.TF_OK, s.Code);
-
- // Run up to the new operation.
- inputs = new Dictionary();
- inputs.Add(feed, new Tensor(7));
- csession.SetInputs(inputs);
- outputs = new TF_Output[] { new TF_Output(neg, 0) };
- csession.SetOutputs(outputs);
- csession.Run(s);
- ASSERT_EQ(TF_Code.TF_OK, s.Code);
-
- outTensor = csession.output_tensor(0);
- ASSERT_TRUE(outTensor != IntPtr.Zero);
- EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
- EXPECT_EQ(0, outTensor.NDims); // scalar
- ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
- output_contents = outTensor.ToArray();
- EXPECT_EQ(-(7 + 2), output_contents[0]);
-
- // Clean up
- csession.CloseAndDelete(s);
- ASSERT_EQ(TF_Code.TF_OK, s.Code);
- }
- }
-
[TestMethod]
public void EvalTensor()
{
diff --git a/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs b/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs
index a81074b3..c562f4e6 100644
--- a/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs
+++ b/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs
@@ -7,201 +7,11 @@ using System.Runtime.InteropServices;
using Tensorflow;
using static Tensorflow.Binding;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace TensorFlowNET.UnitTest
{
- [TestClass]
- public class TensorTest : CApiTest
+ [TestClass, Ignore]
+ public class TensorTest
{
- [TestMethod]
- public unsafe void TensorFromFixed()
- {
- var array = new float[1000];
- var span = new Span(array, 100, 500);
- fixed (float* ptr = &MemoryMarshal.GetReference(span))
- {
- using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, 4 * span.Length))
- {
- Assert.IsFalse(t.IsDisposed);
- Assert.AreEqual(2000, (int)t.bytesize);
- }
- }
-
- fixed (float* ptr = &array[0])
- {
- using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length))
- {
- Assert.IsFalse(t.IsDisposed);
- Assert.AreEqual(4000, (int)t.bytesize);
- }
- }
- }
-
- [TestMethod]
- public unsafe void TensorFromArray()
- {
- var array = new float[1000];
- using (var t = new Tensor(array, new long[] { array.Length }, tf.float32))
- {
- Assert.IsFalse(t.IsDisposed);
- Assert.AreEqual(1000 * sizeof(float), (int)t.bytesize);
- }
-
- using (var t = new Tensor(new float[] { 1 }, new long[] { 1 }, tf.float32))
- {
- Assert.IsFalse(t.IsDisposed);
- Assert.AreEqual(1 * sizeof(float), (int)t.bytesize);
- }
-
- using (var t = new Tensor(new float[] { 1 }, null, tf.float32))
- {
- Assert.IsFalse(t.IsDisposed);
- Assert.AreEqual(1 * sizeof(float), (int)t.bytesize);
- t.shape.Should().BeEmpty();
- }
- }
-
- [TestMethod]
- public void AllocateTensor()
- {
- ulong num_bytes = 6 * sizeof(float);
- long[] dims = { 2, 3 };
- Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes);
- EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
- EXPECT_EQ(2, t.NDims);
- EXPECT_EQ((int)dims[0], t.shape[0]);
- EXPECT_EQ(num_bytes, t.bytesize);
- t.Dispose();
- }
-
-
- ///
- /// Port from c_api_test.cc
- /// `TEST(CAPI, MaybeMove)`
- ///
- [TestMethod, Ignore]
- public void MaybeMove()
- {
- NDArray nd = np.array(2, 3);
- Tensor t = new Tensor(nd);
- Tensor o = t.MaybeMove();
- ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own.
- t.Dispose();
- }
-
- ///
- /// Port from c_api_test.cc
- /// `TEST(CAPI, Tensor)`
- ///
- [TestMethod]
- public void Tensor()
- {
- var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);
-
- var tensor = new Tensor(nd);
- var array = tensor.ToArray();
-
- EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
- EXPECT_EQ(tensor.rank, nd.ndim);
- EXPECT_EQ((int)tensor.shape[0], nd.shape[0]);
- EXPECT_EQ((int)tensor.shape[1], nd.shape[1]);
- EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float));
- Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), new float[] { 1, 2, 3, 4, 5, 6 }));
- }
-
- ///
- /// Port from tensorflow\c\c_api_test.cc
- /// `TEST(CAPI, SetShape)`
- ///
- [TestMethod]
- public void SetShape()
- {
- var s = new Status();
- var graph = new Graph().as_default();
-
- var feed = c_test_util.Placeholder(graph, s);
- var feed_out_0 = new TF_Output(feed, 0);
-
- // Fetch the shape, it should be completely unknown.
- int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
-
- Assert.IsTrue(s.Code == TF_Code.TF_OK);
- EXPECT_EQ(-1, num_dims);
-
- // Set the shape to be unknown, expect no change.
- c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle);
- Assert.IsTrue(s.Code == TF_Code.TF_OK);
- num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
- EXPECT_EQ(-1, num_dims);
-
- // Set the shape to be 2 x Unknown
- long[] dims = { 2, -1 };
- c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle);
- Assert.IsTrue(s.Code == TF_Code.TF_OK);
- num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
- EXPECT_EQ(2, num_dims);
-
- // Get the dimension vector appropriately.
- var returned_dims = new long[dims.Length];
- c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
- Assert.IsTrue(s.Code == TF_Code.TF_OK);
- Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));
-
- // Set to a new valid shape: [2, 3]
- dims[1] = 3;
- c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle);
- Assert.IsTrue(s.Code == TF_Code.TF_OK);
-
- // Fetch and see that the new value is returned.
- c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
- Assert.IsTrue(s.Code == TF_Code.TF_OK);
- Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));
-
- // Try to set 'unknown' with unknown rank on the shape and see that
- // it doesn't change.
- c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle);
- Assert.IsTrue(s.Code == TF_Code.TF_OK);
- c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
- Assert.IsTrue(s.Code == TF_Code.TF_OK);
- EXPECT_EQ(2, num_dims);
- EXPECT_EQ(2, (int)returned_dims[0]);
- EXPECT_EQ(3, (int)returned_dims[1]);
-
- // Try to set 'unknown' with same rank on the shape and see that
- // it doesn't change.
- dims[0] = -1;
- dims[1] = -1;
- c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle);
- Assert.IsTrue(s.Code == TF_Code.TF_OK);
- c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
- Assert.IsTrue(s.Code == TF_Code.TF_OK);
- EXPECT_EQ(2, num_dims);
- EXPECT_EQ(2, (int)returned_dims[0]);
- EXPECT_EQ(3, (int)returned_dims[1]);
-
- // Try to fetch a shape with the wrong num_dims
- c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle);
- Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);
-
- // Try to set an invalid shape (cannot change 2x3 to a 2x5).
- dims[1] = 5;
- c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle);
- Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);
-
- // Test for a scalar.
- var three = c_test_util.ScalarConst(3, graph, s);
- Assert.IsTrue(s.Code == TF_Code.TF_OK);
- var three_out_0 = new TF_Output(three, 0);
-
- num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle);
- Assert.IsTrue(s.Code == TF_Code.TF_OK);
- EXPECT_EQ(0, num_dims);
- c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s.Handle);
- //Assert.IsTrue(s.Code == TF_Code.TF_OK);
-
- // graph.Dispose();
- s.Dispose();
- }
-
[TestMethod]
public void sparse_to_dense()
{
@@ -271,32 +81,5 @@ namespace TensorFlowNET.UnitTest.NativeAPI
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray()));
}
}
-
- ///
- /// Creates a tensor from an image of 256x256x3 and resizes it to 100x100x3
- ///
- [TestMethod]
- public unsafe void tensor_resize()
- {
- tf.enable_eager_execution();
-
- var imageArray = new float[256 * 256 * 3];
-
- using var newSize = tf.convert_to_tensor(new int[] { 100, 100 });
-
- using (var t = tf.constant(imageArray, tf.float32, (1, 256, 256, 3)))
- {
- Assert.IsFalse(t.IsDisposed);
- Assert.AreEqual(256 * 256 * 3 * sizeof(float), (int)t.bytesize);
-
- using var resized = tf.image.resize_bilinear(t, newSize);
- EXPECT_EQ(resized.shape[0], 1);
- EXPECT_EQ(resized.shape[1], 100);
- EXPECT_EQ(resized.shape[2], 100);
- EXPECT_EQ(resized.shape[3], 3);
- }
-
- tf.compat.v1.disable_eager_execution();
- }
}
}
\ No newline at end of file
diff --git a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
index a83800e6..bd25736c 100644
--- a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
+++ b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest
@@ -10,11 +11,40 @@ namespace TensorFlowNET.UnitTest
{
if (!tf.executing_eagerly())
tf.enable_eager_execution();
+ tf.Context.ensure_initialized();
}
[TestCleanup]
public void TestClean()
{
}
+
+ public bool Equal(float[] f1, float[] f2)
+ {
+ bool ret = false;
+ var tolerance = .000001f;
+ for (var i = 0; i < f1.Length; i++)
+ {
+ ret = Math.Abs(f1[i] - f2[i]) <= tolerance;
+ if (!ret)
+ break;
+ }
+
+ return ret;
+ }
+
+ public bool Equal(double[] d1, double[] d2)
+ {
+ bool ret = false;
+ var tolerance = .000000000000001f;
+ for (var i = 0; i < d1.Length; i++)
+ {
+ ret = Math.Abs(d1[i] - d2[i]) <= tolerance;
+ if (!ret)
+ break;
+ }
+
+ return ret;
+ }
}
}
diff --git a/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs b/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs
index 0ac5f33e..c8634542 100644
--- a/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs
+++ b/test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs
@@ -6,10 +6,10 @@ using System.Threading;
using Tensorflow;
using static Tensorflow.Binding;
-namespace TensorFlowNET.UnitTest.NativeAPI
+namespace TensorFlowNET.UnitTest
{
[TestClass]
- public class EnforcedSinglethreadingTests : CApiTest
+ public class EnforcedSinglethreadingTests
{
private static readonly object _singlethreadLocker = new object();
diff --git a/test/TensorFlowNET.UnitTest/GraphModeTestBase.cs b/test/TensorFlowNET.UnitTest/GraphModeTestBase.cs
index 67f4e3a7..8d008ddb 100644
--- a/test/TensorFlowNET.UnitTest/GraphModeTestBase.cs
+++ b/test/TensorFlowNET.UnitTest/GraphModeTestBase.cs
@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using TensorFlowNET.UnitTest;
using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
namespace Tensorflow.UnitTest
{
@@ -15,6 +16,7 @@ namespace Tensorflow.UnitTest
[TestCleanup]
public void TestClean()
{
+ keras.backend.clear_session();
tf.enable_eager_execution();
}
}
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs
index bac400d8..6f816d8f 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ActivationFunctionTest.cs
@@ -5,7 +5,7 @@ using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.nn_test
{
[TestClass]
- public class ActivationFunctionTest : TFNetApiTest
+ public class ActivationFunctionTest : EagerModeTestBase
{
// A constant vector of size 6
Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f });
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs
index 4ec4eb25..e57e5072 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs
@@ -6,7 +6,7 @@ using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.ManagedAPI
{
[TestClass]
- public class BitwiseApiTest : TFNetApiTest
+ public class BitwiseApiTest : EagerModeTestBase
{
[TestInitialize]
public void Init()
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs
index e821d9e7..df00d588 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs
@@ -7,7 +7,7 @@ using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.ManagedAPI
{
[TestClass]
- public class FunctionApiTest : TFNetApiTest
+ public class FunctionApiTest : EagerModeTestBase
{
Tensor Min(Tensor a, Tensor b)
{
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs
index 874ab053..26e89404 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs
@@ -6,7 +6,7 @@ using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.ManagedAPI
{
[TestClass]
- public class MathApiTest : TFNetApiTest
+ public class MathApiTest : EagerModeTestBase
{
// A constant vector of size 6
Tensor a = tf.constant(new float[] { 1.0f, -0.5f, 3.4f, -2.1f, 0.0f, -6.5f });
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/TFNetApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/TFNetApiTest.cs
deleted file mode 100644
index 5fd20899..00000000
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/TFNetApiTest.cs
+++ /dev/null
@@ -1,35 +0,0 @@
-using System;
-
-namespace TensorFlowNET.UnitTest
-{
- public class TFNetApiTest
- {
- public bool Equal(float[] f1, float[] f2)
- {
- bool ret = false;
- var tolerance = .000001f;
- for (var i = 0; i < f1.Length; i++)
- {
- ret = Math.Abs(f1[i] - f2[i]) <= tolerance;
- if (!ret)
- break;
- }
-
- return ret;
- }
-
- public bool Equal(double[] d1, double[] d2)
- {
- bool ret = false;
- var tolerance = .000000000000001f;
- for (var i = 0; i < d1.Length; i++)
- {
- ret = Math.Abs(d1[i] - d2[i]) <= tolerance;
- if (!ret)
- break;
- }
-
- return ret;
- }
- }
-}
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ZeroFractionTest.cs
deleted file mode 100644
index 349a6c39..00000000
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/ZeroFractionTest.cs
+++ /dev/null
@@ -1,86 +0,0 @@
-using Microsoft.VisualStudio.TestTools.UnitTesting;
-using NumSharp;
-using System;
-using System.Linq;
-using Tensorflow;
-using Tensorflow.UnitTest;
-
-namespace TensorFlowNET.UnitTest.nn_test
-{
- [TestClass]
- public class ZeroFractionTest : GraphModeTestBase
- {
- protected double _ZeroFraction(NDArray x)
- {
- assert(x.shape);
- int total_elements = np.prod(x.shape);
-
- var eps = 1e-8;
- var nonzeros = x.Data().Count(d => Math.Abs(d) > eps);
- return 1.0 - nonzeros / (double)total_elements;
- }
-
- [Ignore("TODO implement nn_impl.zero_fraction")]
- [TestMethod]
- public void testZeroFraction()
- {
- var x_shape = new Shape(5, 17);
- var x_np = np.random.randint(0, 2, x_shape);
- //x_np.astype(np.float32);
- var y_np = this._ZeroFraction(x_np);
-
- var x_tf = constant_op.constant(x_np);
- x_tf.set_shape(x_shape);
- var y_tf = nn_impl.zero_fraction(x_tf);
- var y_tf_np = self.evaluate(y_tf);
-
- var eps = 1e-8;
- self.assertAllClose(y_tf_np, y_np, eps);
- }
-
- [Ignore("TODO implement nn_impl.zero_fraction")]
- [TestMethod]
- public void testZeroFractionEmpty()
- {
-
- var x = np.zeros(0);
- var y = self.evaluate(nn_impl.zero_fraction(new Tensor(x)));
- self.assertTrue(np.isnan(y));
- }
-
- [Ignore("TODO implement nn_impl.zero_fraction")]
- [TestMethod]
- public void testZeroFraction2_27Zeros()
- {
- var sparsity = nn_impl.zero_fraction(
- array_ops.zeros(new Shape((int)Math.Pow(2, 27 * 1.01)), dtypes.int8));
- self.assertAllClose(1.0, self.evaluate(sparsity));
- }
-
- [Ignore("TODO implement nn_impl.zero_fraction")]
- [TestMethod]
- public void testZeroFraction2_27Ones()
- {
- var sparsity = nn_impl.zero_fraction(
- array_ops.ones(new TensorShape((int)Math.Pow(2, 27 * 1.01)), dtypes.int8));
- self.assertAllClose(0.0, self.evaluate(sparsity));
- }
-
- [Ignore("TODO implement nn_impl.zero_fraction")]
- [TestMethod]
- public void testUnknownSize()
- {
- var value = array_ops.placeholder(dtype: dtypes.float32);
- var sparsity = nn_impl.zero_fraction(value);
- using (var sess = self.cached_session())
- {
- // TODO: make this compile
- //self.assertAllClose(
- // 0.25,
- // sess.run(sparsity, {value: [[0., 1.], [0.3, 2.]]}));
- }
- }
-
-
- }
-}
diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs
index 91580d31..10f32ca2 100644
--- a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs
+++ b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs
@@ -73,6 +73,7 @@ namespace TensorFlowNET.UnitTest
tf.peak_default_graph().Should().BeNull();
var beforehand = tf.get_default_graph(); //this should create default automatically.
beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread.");
+ beforehand.as_default();
tf.peak_default_graph().Should().NotBeNull();
using (var sess = tf.Session())
diff --git a/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs b/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs
index ed379f3c..d1cda728 100644
--- a/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs
+++ b/test/TensorFlowNET.UnitTest/Utilities/TestHelper.cs
@@ -1,4 +1,5 @@
-using System.IO;
+using System;
+using System.IO;
namespace TensorFlowNET.UnitTest
{
@@ -6,8 +7,16 @@ namespace TensorFlowNET.UnitTest
{
public static string GetFullPathFromDataDir(string fileName)
{
- var dir = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "..", "..", "data");
- return Path.GetFullPath(Path.Combine(dir, fileName));
+ var dataDir = GetRootContentDir(Directory.GetCurrentDirectory());
+ return Path.Combine(dataDir, fileName);
+ }
+
+ static string GetRootContentDir(string dir)
+ {
+ var path = Path.GetFullPath(Path.Combine(dir, "data"));
+ if (Directory.Exists(path))
+ return path;
+ return GetRootContentDir(Path.GetFullPath(Path.Combine(dir, "..")));
}
}
}
diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/control_flow_ops_test.py b/test/TensorFlowNET.UnitTest/control_flow_ops_test/control_flow_ops_test.py
deleted file mode 100644
index f1dd4f52..00000000
--- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/control_flow_ops_test.py
+++ /dev/null
@@ -1,1059 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for control_flow_ops.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import collections
-import numpy as np
-
-from tensorflow.python import tf2
-from tensorflow.core.framework import graph_pb2
-from tensorflow.core.framework import node_def_pb2
-from tensorflow.python.eager import def_function
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import embedding_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import tensor_array_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
-from tensorflow.python.platform import googletest
-from tensorflow.python.training import momentum
-from tensorflow.python.util import nest
-
-
-TestTuple = collections.namedtuple("TestTuple", "a b")
-SingletonTestTuple = collections.namedtuple("SingletonTestTuple", "a")
-
-
-class GroupTestCase(test_util.TensorFlowTestCase):
-
- def _StripNode(self, nd):
- snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
- if nd.device:
- snode.device = nd.device
- return snode
-
- def _StripGraph(self, gd):
- """Copy gd keeping only, node.name, node.op, node.input, and node.device."""
- return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node])
-
- def testGroup_NoDevices(self):
- with ops.Graph().as_default() as g:
- a = constant_op.constant(0, name="a")
- b = constant_op.constant(0, name="b")
- c = constant_op.constant(0, name="c")
- control_flow_ops.group(a.op, b.op, c.op, name="root")
- gd = g.as_graph_def()
- self.assertProtoEquals("""
- node { name: "a" op: "Const"}
- node { name: "b" op: "Const"}
- node { name: "c" op: "Const"}
- node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
- """, self._StripGraph(gd))
-
- def testGroup_OneDevice(self):
- with ops.Graph().as_default() as g:
- with g.device("/task:0"):
- a = constant_op.constant(0, name="a")
- b = constant_op.constant(0, name="b")
- control_flow_ops.group(a.op, b.op, name="root")
- gd = g.as_graph_def()
- self.assertProtoEquals("""
- node { name: "a" op: "Const" device: "/task:0" }
- node { name: "b" op: "Const" device: "/task:0" }
- node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
- """, self._StripGraph(gd))
-
- def testGroup_MultiDevice(self):
- with ops.Graph().as_default() as g:
- with g.device("/task:0"):
- a = constant_op.constant(0, name="a")
- b = constant_op.constant(0, name="b")
- with g.device("/task:1"):
- c = constant_op.constant(0, name="c")
- d = constant_op.constant(0, name="d")
- with g.device("/task:2"):
- control_flow_ops.group(a.op, b.op, c.op, d.op, name="root")
- gd = g.as_graph_def()
- self.assertProtoEquals("""
- node { name: "a" op: "Const" device: "/task:0"}
- node { name: "b" op: "Const" device: "/task:0"}
- node { name: "c" op: "Const" device: "/task:1"}
- node { name: "d" op: "Const" device: "/task:1"}
- node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
- device: "/task:0" }
- node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
- device: "/task:1" }
- node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
- device: "/task:2" }
- """, self._StripGraph(gd))
-
- def testPassingList(self):
- with ops.Graph().as_default() as g:
- a = constant_op.constant(0, name="a")
- b = constant_op.constant(0, name="b")
- control_flow_ops.group([a.op, b.op], name="root")
- gd = g.as_graph_def()
- self.assertProtoEquals("""
- node { name: "a" op: "Const"}
- node { name: "b" op: "Const"}
- node { name: "root" op: "NoOp" input: "^a" input: "^b" }
- """, self._StripGraph(gd))
-
- @test_util.run_deprecated_v1
- def testPassingNonTensors(self):
- with self.assertRaises(TypeError):
- control_flow_ops.group(1, 2)
-
-
-class ShapeTestCase(test_util.TensorFlowTestCase):
-
- def testShape(self):
- tensor = constant_op.constant([1.0, 2.0])
- self.assertEquals([2], tensor.get_shape())
- self.assertEquals([2],
- control_flow_ops.with_dependencies(
- [constant_op.constant(1.0)], tensor).get_shape())
-
-
-class WithDependenciesTestCase(test_util.TensorFlowTestCase):
-
- @test_util.run_deprecated_v1
- def testTupleDependencies(self):
- counter = variable_scope.get_variable(
- "my_counter", shape=[], initializer=init_ops.zeros_initializer())
- increment_counter = state_ops.assign_add(counter, 1)
- const_with_dep = control_flow_ops.with_dependencies(
- (increment_counter, constant_op.constant(42)),
- constant_op.constant(7))
-
- self.evaluate(variables.global_variables_initializer())
- self.assertEquals(0, self.evaluate(counter))
- self.assertEquals(7, self.evaluate(const_with_dep))
- self.assertEquals(1, self.evaluate(counter))
-
- @test_util.run_deprecated_v1
- def testListDependencies(self):
- counter = variable_scope.get_variable(
- "my_counter", shape=[], initializer=init_ops.zeros_initializer())
- increment_counter = state_ops.assign_add(counter, 1)
- const_with_dep = control_flow_ops.with_dependencies(
- [increment_counter, constant_op.constant(42)],
- constant_op.constant(7))
-
- self.evaluate(variables.global_variables_initializer())
- self.assertEquals(0, self.evaluate(counter))
- self.assertEquals(7, self.evaluate(const_with_dep))
- self.assertEquals(1, self.evaluate(counter))
-
-
-class SwitchTestCase(test_util.TensorFlowTestCase):
-
- @test_util.run_deprecated_v1
- def testIndexedSlicesWithDenseShape(self):
- with self.cached_session():
- data = ops.IndexedSlices(
- constant_op.constant([1, 2, 3]),
- constant_op.constant([0, 1]),
- dense_shape=constant_op.constant([3]))
- zero = constant_op.constant(0)
- one = constant_op.constant(1)
- less_op = math_ops.less(zero, one)
- _, switch_true = control_flow_ops.switch(data, less_op)
- self.assertAllEqual([1, 2, 3], switch_true.values.eval())
- self.assertAllEqual([0, 1], switch_true.indices.eval())
-
- @test_util.run_deprecated_v1
- def testIndexedSlicesGradient(self):
- embedding_matrix = variable_scope.get_variable(
- "embedding_matrix", [5, 5],
- initializer=init_ops.random_normal_initializer())
-
- def cond(it, _):
- return it < 5
-
- def body(it, cost):
- embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
- cost += math_ops.reduce_sum(embedding)
- return it + 1, cost
-
- _, cost = control_flow_ops.while_loop(
- cond, body, [constant_op.constant(0),
- constant_op.constant(0.0)])
- optimizer = momentum.MomentumOptimizer(0.1, 0.9)
- train_op = optimizer.minimize(cost)
- with self.cached_session():
- self.evaluate(variables.global_variables_initializer())
- for _ in range(10):
- self.evaluate([train_op])
-
- def testResourceReadInLoop(self):
- embedding_matrix = variable_scope.get_variable(
- "embedding_matrix", initializer=[[2.0], [3.0]], use_resource=True)
-
- def cond(it, _):
- return it < 5
-
- def body(it, cost):
- embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
- cost += math_ops.reduce_sum(embedding)
- return it + 1, cost
-
- _, cost = control_flow_ops.while_loop(
- cond, body, [constant_op.constant(0),
- constant_op.constant(0.0)])
- with self.cached_session():
- self.evaluate(variables.global_variables_initializer())
- self.assertAllEqual(10.0, self.evaluate(cost))
-
- def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
- embedding_matrix = variable_scope.get_variable(
- "embedding_matrix", [5, 5],
- initializer=init_ops.random_normal_initializer(),
- use_resource=use_resource)
-
- def cond(it, _):
- return it < 5
-
- def body(it, cost):
- embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
- cost = control_flow_ops.cond(
- math_ops.equal(it, 3), lambda: math_ops.square(cost),
- (lambda: cost + math_ops.reduce_sum(embedding)))
- return it + 1, cost
-
- _, cost = control_flow_ops.while_loop(
- cond, body, [constant_op.constant(0),
- constant_op.constant(0.0)])
-
- dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
- dynamic_grads = math_ops.segment_sum(dynamic_grads.values,
- dynamic_grads.indices)
-
- embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
- static = math_ops.square(
- math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
- math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
- static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
- static_grads = math_ops.segment_sum(static_grads.values,
- static_grads.indices)
-
- with self.cached_session():
- self.evaluate(variables.global_variables_initializer())
- self.assertAllEqual(*self.evaluate([static_grads, dynamic_grads]))
-
- def testIndexedSlicesGradientInCondInWhileLoop(self):
- self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False)
-
- def testIndexedSlicesGradientInCondInWhileLoopResource(self):
- self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=True)
-
- @test_util.run_v1_only("b/120545219")
- def testIndexedSlicesWithShapeGradientInWhileLoop(self):
- for dtype in [dtypes.float32, dtypes.float64]:
- with self.cached_session() as sess:
- num_steps = 9
-
- inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
- initial_outputs = tensor_array_ops.TensorArray(
- dtype=dtype, size=num_steps)
- initial_i = constant_op.constant(0, dtype=dtypes.int32)
-
- def cond(i, _):
- return i < num_steps # pylint: disable=cell-var-from-loop
-
- def body(i, outputs):
- x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop
- outputs = outputs.write(i, x)
- return i + 1, outputs
-
- _, outputs = control_flow_ops.while_loop(cond, body,
- [initial_i, initial_outputs])
-
- outputs = math_ops.reduce_sum(outputs.stack())
- r = gradients_impl.gradients([outputs], [inputs])[0]
- grad_wr_inputs = ops.convert_to_tensor(r)
- o, grad = sess.run([outputs, grad_wr_inputs],
- feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
- self.assertEquals(o, 20)
- self.assertAllEqual(grad, [1] * num_steps)
-
- @test_util.run_v1_only("b/120545219")
- def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
- for dtype in [dtypes.float32, dtypes.float64]:
- with self.cached_session() as sess:
- inputs = array_ops.placeholder(dtype=dtype)
- initial_outputs = tensor_array_ops.TensorArray(
- dtype=dtype, dynamic_size=True, size=1)
- initial_i = constant_op.constant(0, dtype=dtypes.int32)
-
- def cond(i, _):
- return i < array_ops.size(inputs) # pylint: disable=cell-var-from-loop
-
- def body(i, outputs):
- x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop
- outputs = outputs.write(i, x)
- return i + 1, outputs
-
- _, outputs = control_flow_ops.while_loop(cond, body,
- [initial_i, initial_outputs])
-
- outputs = math_ops.reduce_sum(outputs.stack())
- r = gradients_impl.gradients([outputs], [inputs])[0]
- grad_wr_inputs = ops.convert_to_tensor(r)
- o, grad = sess.run([outputs, grad_wr_inputs],
- feed_dict={inputs: [1, 3, 2]})
- self.assertEquals(o, 6)
- self.assertAllEqual(grad, [1] * 3)
-
- @test_util.run_deprecated_v1
- def testGradientThroughSingleBranchOutsideOfContext(self):
- x = constant_op.constant(2.)
- s = constant_op.constant(True)
- x_false, x_true = control_flow_ops.switch(x, s)
- grad_x_true = gradients_impl.gradients(x_true, x)[0]
- grad_x_false = gradients_impl.gradients(x_false, x)[0]
- self.assertEquals(self.evaluate(grad_x_true), 1.)
- self.assertEquals(self.evaluate(grad_x_false), 0.)
-
-
-class CondTest(test_util.TensorFlowTestCase):
-
- def testCondTrue(self):
- x = constant_op.constant(2)
- y = constant_op.constant(5)
- z = control_flow_ops.cond(
- math_ops.less(
- x,
- y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23))
- self.assertEquals(self.evaluate(z), 34)
-
- def testCondFalse(self):
- x = constant_op.constant(2)
- y = constant_op.constant(1)
- z = control_flow_ops.cond(
- math_ops.less(
- x,
- y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23))
- self.assertEquals(self.evaluate(z), 24)
-
- def testCondTrueLegacy(self):
- x = constant_op.constant(2)
- y = constant_op.constant(5)
- z = control_flow_ops.cond(
- math_ops.less(x, y),
- fn1=lambda: math_ops.multiply(x, 17),
- fn2=lambda: math_ops.add(y, 23))
- self.assertEquals(self.evaluate(z), 34)
-
- def testCondFalseLegacy(self):
- x = constant_op.constant(2)
- y = constant_op.constant(1)
- z = control_flow_ops.cond(
- math_ops.less(x, y),
- fn1=lambda: math_ops.multiply(x, 17),
- fn2=lambda: math_ops.add(y, 23))
- self.assertEquals(self.evaluate(z), 24)
-
- @test_util.run_deprecated_v1
- def testCondModifyBoolPred(self):
- # This test in particular used to fail only when running in GPU, hence
- # use_gpu=True.
- with test_util.use_gpu():
- bool_var = variable_scope.get_variable(
- "bool_var", dtype=dtypes.bool, initializer=True)
- cond_on_bool_var = control_flow_ops.cond(
- pred=bool_var,
- true_fn=lambda: state_ops.assign(bool_var, False),
- false_fn=lambda: True)
- self.evaluate(bool_var.initializer)
- self.assertEquals(self.evaluate(cond_on_bool_var), False)
- self.assertEquals(self.evaluate(cond_on_bool_var), True)
-
- def testCondMissingArg1(self):
- x = constant_op.constant(1)
- with self.assertRaises(TypeError):
- control_flow_ops.cond(True, false_fn=lambda: x)
-
- def testCondMissingArg2(self):
- x = constant_op.constant(1)
- with self.assertRaises(TypeError):
- control_flow_ops.cond(True, lambda: x)
-
- def testCondDuplicateArg1(self):
- x = constant_op.constant(1)
- with self.assertRaises(TypeError):
- control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
-
- def testCondDuplicateArg2(self):
- x = constant_op.constant(1)
- with self.assertRaises(TypeError):
- control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
-
-
-class ContextTest(test_util.TensorFlowTestCase):
-
- @test_util.run_deprecated_v1
- def testCondContext(self):
- with self.cached_session() as sess:
- x = constant_op.constant(2)
- y = constant_op.constant(5)
- control_flow_ops.cond(
- math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
- lambda: math_ops.add(y, 23))
- for op in sess.graph.get_operations():
- c = op._get_control_flow_context()
- if c:
- self.assertProtoEquals(
- c.to_proto(),
- control_flow_ops.CondContext.from_proto(c.to_proto()).to_proto())
-
- def _testWhileContextHelper(self, maximum_iterations=None):
- with self.cached_session() as sess:
- i = constant_op.constant(0)
- c = lambda i: math_ops.less(i, 10)
- b = lambda i: math_ops.add(i, 1)
- control_flow_ops.while_loop(
- c, b, [i], maximum_iterations=maximum_iterations)
- for op in sess.graph.get_operations():
- control_flow_context = op._get_control_flow_context()
- if control_flow_context:
- self.assertProtoEquals(
- control_flow_context.to_proto(),
- control_flow_ops.WhileContext.from_proto(
- control_flow_context.to_proto()).to_proto())
-
- @test_util.run_deprecated_v1
- def testWhileContext(self):
- self._testWhileContextHelper()
-
- @test_util.run_deprecated_v1
- def testWhileContextWithMaximumIterations(self):
- self._testWhileContextHelper(maximum_iterations=10)
-
- @test_util.run_deprecated_v1
- def testControlContextImportScope(self):
- class NoABCControlFlowContext(control_flow_ops.ControlFlowContext):
- """A noop wrapper around `ControlFlowContext`.
-
- `ControlFlowContext` is an ABC and therefore cannot be instantiated.
- """
- # pylint: disable=useless-super-delegation
-
- def to_control_flow_context_def(self, context_def, export_scope=None):
- super(NoABCControlFlowContext, self).to_control_flow_context_def(
- context_def, export_scope)
-
- with self.cached_session():
- constant_op.constant(0, name="a")
- constant_op.constant(2, name="test_scope/a")
- b1 = constant_op.constant(1, name="b")
- b2 = constant_op.constant(3, name="test_scope/b")
-
- c = NoABCControlFlowContext()
- c._values = ["a", "b"]
- c._external_values = {"a": b1}
-
- c_with_scope = NoABCControlFlowContext(
- values_def=c._to_values_def(), import_scope="test_scope")
-
- # _values and _external_values should be have scope prepended.
- self.assertEquals(
- c_with_scope._values, set(["test_scope/a", "test_scope/b"]))
- self.assertEquals(
- c_with_scope._external_values, {"test_scope/a": b2})
-
- # Calling _to_proto() with export_scope should remove "test_scope".
- self.assertProtoEquals(
- c._to_values_def(),
- c_with_scope._to_values_def(export_scope="test_scope"))
-
-
-def _get_nested_shape(nested):
-
- def _get_shape(tensor):
- if isinstance(tensor, tensor_array_ops.TensorArray):
- return tensor_array_ops.TensorArray
- elif isinstance(tensor, ops.IndexedSlices):
- return tensor.dense_shape
- else:
- return tensor.get_shape()
-
- return nest.map_structure(_get_shape, nested)
-
-
-def _create_tensor_array(size, shape):
- ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size,
- clear_after_read=False)
- for i in range(size):
- ta = ta.write(i, array_ops.zeros(shape))
- return ta
-
-
-def _raw_nested_shape(nested_shape):
-
- def _raw_shape(shape):
- if isinstance(shape, tensor_shape.TensorShape) and shape.ndims is not None:
- return [x.value for x in shape.dims]
- else:
- return None
-
- return nest.map_structure(_raw_shape, nested_shape)
-
-
-# TODO(yori): Add tests for indexed slices.
-class DataTypesTest(test_util.TensorFlowTestCase):
-
- def assertAllEqualNested(self, a, b):
- if isinstance(a, (list, tuple)):
- for entry_a, entry_b in zip(a, b):
- self.assertAllEqualNested(entry_a, entry_b)
- else:
- self.assertAllEqual(a, b)
-
- def _testShape(self, fn_true, fn_false, expected_shape,
- strict=False):
- condition = array_ops.placeholder(dtypes.bool)
- output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
- strict=strict)
- self.assertEqual(
- _raw_nested_shape(_get_nested_shape(output_cond)),
- _raw_nested_shape(expected_shape))
-
- output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
- strict=strict)
- self.assertEqual(
- _raw_nested_shape(_get_nested_shape(output_case)),
- _raw_nested_shape(expected_shape))
-
- def _testReturnValues(self, fn_true, fn_false, expected_value_true,
- expected_value_false, strict=False,
- check_cond=True, feed_dict=None):
- if feed_dict is None: feed_dict = {}
-
- condition = array_ops.placeholder(dtypes.bool)
- output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
- strict=strict)
- output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
- strict=strict)
-
- with self.cached_session() as sess:
- self.evaluate(variables.global_variables_initializer())
- true_feed_dict = {condition: True}
- true_feed_dict.update(feed_dict)
- result_cond, result_case = sess.run([output_cond, output_case],
- feed_dict=true_feed_dict)
- self.assertAllEqualNested(result_cond, expected_value_true)
- if check_cond:
- self.assertAllEqualNested(result_case, expected_value_true)
- false_feed_dict = {condition: False}
- false_feed_dict.update(feed_dict)
- result_cond, result_case = sess.run([output_cond, output_case],
- feed_dict=false_feed_dict)
- self.assertAllEqualNested(result_cond, expected_value_false)
- if check_cond:
- self.assertAllEqualNested(result_case, expected_value_false)
-
- @test_util.run_deprecated_v1
- def test_int(self):
- shape = tensor_shape.TensorShape([])
- fn_true = lambda: 1
- fn_false = lambda: 2
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false, 1, 2)
- self._testShape(fn_true, fn_false, shape, strict=True)
- self._testReturnValues(fn_true, fn_false, 1, 2, strict=True)
-
- @test_util.run_deprecated_v1
- def test_float(self):
- shape = tensor_shape.TensorShape([])
- fn_true = lambda: 1.0
- fn_false = lambda: 2.0
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false, 1.0, 2.0)
-
- @test_util.run_deprecated_v1
- def test_noop(self):
- shape = tensor_shape.TensorShape(None)
- self._testShape(control_flow_ops.no_op, control_flow_ops.no_op, shape)
- self._testReturnValues(control_flow_ops.no_op, control_flow_ops.no_op,
- True, False, check_cond=False)
-
- @test_util.run_deprecated_v1
- def test_string(self):
- shape = tensor_shape.TensorShape([])
- fn_true = lambda: "abc"
- fn_false = lambda: "xyz"
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false, b"abc", b"xyz")
-
- @test_util.run_deprecated_v1
- def test_variable(self):
- shape = tensor_shape.TensorShape([])
- fn_true = lambda: variables.Variable(3.0)
- fn_false = lambda: variables.Variable(4.0)
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false, 3.0, 4.0)
-
- @test_util.run_v1_only("b/120553181")
- def test_none(self):
- fn_none = lambda: None
- fn_tensor = lambda: constant_op.constant(1)
-
- with self.assertRaises(ValueError):
- control_flow_ops.cond(constant_op.constant(True), fn_none, fn_tensor)
-
- with self.assertRaises(ValueError):
- control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_none)
-
- @test_util.run_deprecated_v1
- def test_tensors(self):
-
- def _build_true_branch(dtype):
-
- def _build():
- return (array_ops.zeros([2, 2], dtype=dtype),
- array_ops.ones([3, 3], dtype=dtype))
-
- return _build
-
- def _build_false_branch(dtype):
-
- def _build():
- return (array_ops.ones([2, 2], dtype=dtype),
- array_ops.zeros([3, 3], dtype=dtype))
-
- return _build
-
- for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
- shape = (tensor_shape.TensorShape([2, 2]),
- tensor_shape.TensorShape([3, 3]))
- fn_true = _build_true_branch(dtype)
- fn_false = _build_false_branch(dtype)
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false,
- (np.zeros([2, 2]), np.ones([3, 3])),
- (np.ones([2, 2]), np.zeros([3, 3])))
-
- @test_util.run_deprecated_v1
- def test_tensors_unknown_shape(self):
-
- def _build_true_branch(dtype):
- tensor = array_ops.placeholder(dtype=dtype, shape=None)
-
- def _build():
- return tensor
-
- return _build, tensor
-
- def _build_false_branch(dtype):
- tensor = array_ops.placeholder(dtype=dtype, shape=None)
-
- def _build():
- return tensor
-
- return _build, tensor
-
- for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
- shape = tensor_shape.TensorShape(None)
- fn_true, true_tensor = _build_true_branch(dtype)
- fn_false, false_tensor = _build_false_branch(dtype)
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false,
- np.zeros([2, 2]), np.ones([2, 2]),
- feed_dict={true_tensor: np.zeros([2, 2]),
- false_tensor: np.ones([2, 2])})
-
- @test_util.run_deprecated_v1
- def test_sparse_tensors(self):
- shape = tensor_shape.TensorShape([None, None])
-
- def true_fn():
- return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]],
- values=[1, 2], dense_shape=[3, 4])]
-
- def false_fn():
- return [sparse_tensor.SparseTensor(indices=[[0, 0], [2, 1]],
- values=[3, 4], dense_shape=[3, 4])]
-
- value1 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 2]],
- values=[1, 2], dense_shape=[3, 4])
- value2 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [2, 1]],
- values=[3, 4], dense_shape=[3, 4])
- # Non-strict cond is only available in v1
- if not tf2.enabled():
- self._testShape(true_fn, false_fn, shape)
- self._testReturnValues(true_fn, false_fn, value1, value2)
- self._testShape(true_fn, false_fn, [shape], strict=True)
- self._testReturnValues(true_fn, false_fn, [value1], [value2], strict=True)
-
- @test_util.run_deprecated_v1
- def test_tensors_with_partially_specified_shapes(self):
-
- def _build_branch(dtype, shape):
- a = array_ops.placeholder(dtype=dtype, shape=shape[0])
- b = array_ops.placeholder(dtype=dtype, shape=shape[1])
- c = array_ops.placeholder(dtype=dtype, shape=shape[2])
-
- def _build():
- return a, b, c
-
- return _build, (a, b, c)
-
- for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
- shape = (tensor_shape.TensorShape([None, 2]),
- tensor_shape.TensorShape([None]),
- tensor_shape.TensorShape([3, None]))
- fn_true, true_tensors = _build_branch(dtype, shape)
- fn_false, false_tensors = _build_branch(dtype, shape)
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false,
- (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
- (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
- feed_dict={true_tensors[0]: np.zeros([2, 2]),
- false_tensors[0]: np.zeros([2, 2]),
- true_tensors[1]: np.zeros([5]),
- false_tensors[1]: np.zeros([5]),
- true_tensors[2]: np.ones([3, 3]),
- false_tensors[2]: np.ones([3, 3])})
-
- @test_util.run_deprecated_v1
- def test_tensor_arrays(self):
- element_shape = tensor_shape.TensorShape([2])
- ta1 = _create_tensor_array(4, element_shape)
- ta2 = _create_tensor_array(4, element_shape)
- shape = tensor_array_ops.TensorArray
- fn_true = lambda: ta1
- fn_false = lambda: ta2
- self._testShape(fn_true, fn_false, shape)
-
- @test_util.run_deprecated_v1
- def test_tensor_array_reads(self):
- shape = tensor_shape.TensorShape([2])
- ta = _create_tensor_array(4, shape)
- fn_true = lambda: ta.read(0)
- fn_false = lambda: ta.read(1)
- self._testShape(fn_true, fn_false, shape)
-
- @test_util.run_deprecated_v1
- def test_list(self):
- shape = [tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
- tensor_shape.TensorShape([])]
- fn_true = lambda: [constant_op.constant(1), 2, variables.Variable(3.0)]
- fn_false = lambda: [constant_op.constant(3), 4, variables.Variable(5.0)]
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false, [1, 2, 3.0], [3, 4, 5.0])
-
- @test_util.run_v1_only("Non-strict cond is only available in v1")
- def test_non_strict(self):
- shape = tensor_shape.TensorShape([])
- fn_tensor = lambda: constant_op.constant(1)
- fn_list = lambda: [constant_op.constant(2)]
- fn_tuple = lambda: (constant_op.constant(3),)
- self._testShape(fn_tensor, fn_list, shape)
- self._testShape(fn_tensor, fn_tuple, shape)
- self._testShape(fn_list, fn_tuple, shape)
- self._testReturnValues(fn_tensor, fn_list, 1, 2)
- self._testReturnValues(fn_tensor, fn_tuple, 1, 3)
- self._testReturnValues(fn_list, fn_tuple, 2, 3)
-
- @test_util.run_v1_only("b/120553181")
- def test_singleton_strict(self):
- fn_tensor = lambda: constant_op.constant(1)
- fn_list = lambda: [constant_op.constant(2)]
- fn_tuple = lambda: (constant_op.constant(3),)
-
- with self.assertRaises(ValueError):
- control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_list,
- strict=True)
-
- with self.assertRaises(TypeError):
- control_flow_ops.cond(constant_op.constant(True), fn_list, fn_tuple,
- strict=True)
-
- with self.assertRaises(ValueError):
- control_flow_ops.case([(constant_op.constant(True), fn_tensor)], fn_list,
- strict=True)
-
- with self.assertRaises(TypeError):
- control_flow_ops.case([(constant_op.constant(True), fn_list)], fn_tuple,
- strict=True)
-
- @test_util.run_deprecated_v1
- def test_singleton_list(self):
- shape = tensor_shape.TensorShape([])
- fn_true = lambda: [constant_op.constant(1)]
- fn_false = lambda: [constant_op.constant(3)]
- # Non-strict cond is only available in v1
- if not tf2.enabled():
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false, 1, 3)
- self._testShape(fn_true, fn_false, [shape], strict=True)
- self._testReturnValues(fn_true, fn_false, [1], [3], strict=True)
-
- @test_util.run_deprecated_v1
- def test_singleton_tuple(self):
- shape = tensor_shape.TensorShape([])
- fn_true = lambda: (constant_op.constant(1),)
- fn_false = lambda: (constant_op.constant(3),)
- # Non-strict cond is only available in v1
- if not tf2.enabled():
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false, 1, 3)
- self._testShape(fn_true, fn_false, (shape,), strict=True)
- self._testReturnValues(fn_true, fn_false, (1,), (3,),
- strict=True)
-
- @test_util.run_deprecated_v1
- def test_singleton_namedtuple(self):
- shape = tensor_shape.TensorShape([])
- fn_true = lambda: SingletonTestTuple(constant_op.constant(1))
- fn_false = lambda: SingletonTestTuple(constant_op.constant(3))
- # Non-strict cond is only available in v1
- if not tf2.enabled():
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false, 1, 3)
- self._testShape(fn_true, fn_false, SingletonTestTuple(shape),
- strict=True)
- self._testReturnValues(fn_true, fn_false, SingletonTestTuple(1),
- SingletonTestTuple(3), strict=True)
-
- @test_util.run_deprecated_v1
- def test_tuple(self):
- shape = (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
- fn_true = lambda: (constant_op.constant(1), 2)
- fn_false = lambda: (constant_op.constant(3), 4)
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false, (1, 2), (3, 4))
-
- @test_util.run_deprecated_v1
- def test_namedtuple(self):
- shape = TestTuple(tensor_shape.TensorShape([]),
- tensor_shape.TensorShape([]))
- fn_true = lambda: TestTuple(constant_op.constant(1), 2)
- fn_false = lambda: TestTuple(constant_op.constant(3), 4)
- self._testShape(fn_true, fn_false, shape)
- self._testReturnValues(fn_true, fn_false, TestTuple(1, 2), TestTuple(3, 4))
-
- @test_util.run_deprecated_v1
- def test_nested(self):
- shape = [tensor_shape.TensorShape([]),
- TestTuple(tensor_shape.TensorShape([]),
- [tensor_shape.TensorShape([]),
- tensor_shape.TensorShape([])]),
- tensor_shape.TensorShape([5, 5]),
- tensor_shape.TensorShape([])]
-
- def true_fn():
- return [constant_op.constant(1),
- TestTuple(constant_op.constant(2), [3, 4]),
- array_ops.zeros([5, 5]), 6]
-
- def false_fn():
- return [constant_op.constant(11),
- TestTuple(constant_op.constant(12), [13, 14]),
- array_ops.ones([5, 5]), 16]
-
- self._testShape(true_fn, false_fn, shape)
- self._testReturnValues(
- true_fn, false_fn,
- [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6],
- [11, TestTuple(12, [13, 14]),
- np.ones([5, 5]), 16])
-
- @test_util.run_deprecated_v1
- def test_cond_inside_while_loop(self):
-
- def body(i, matrix):
- result_tuple, unused_matrix = control_flow_ops.cond(
- constant_op.constant(True),
- lambda: (TestTuple(matrix * 2, matrix * 4), matrix),
- lambda: (TestTuple(matrix * 4, matrix * 2), matrix))
- return [i+1, result_tuple.a]
-
- iteration, matrix = control_flow_ops.while_loop(
- lambda i, matrix: i < 10,
- body,
- loop_vars=[constant_op.constant(0),
- array_ops.ones([2, 2])])
-
- self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([]))
- self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2]))
-
-
-class CaseTest(test_util.TensorFlowTestCase):
-
- @test_util.run_deprecated_v1
- def testCase_withDefault(self):
- x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
- conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
- (math_ops.equal(x, 2), lambda: constant_op.constant(4))]
- default = lambda: constant_op.constant(6)
- output = control_flow_ops.case(conditions, default, exclusive=True)
- with self.cached_session() as sess:
- self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
- self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
- self.assertEqual(sess.run(output, feed_dict={x: 3}), 6)
-
- @test_util.run_deprecated_v1
- def testCase_multiple_matches_exclusive(self):
- x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
- conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
- (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
- (math_ops.equal(x, 2), lambda: constant_op.constant(6))]
- default = lambda: constant_op.constant(8)
- output = control_flow_ops.case(conditions, default, exclusive=True)
- with self.cached_session() as sess:
- self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
- self.assertEqual(sess.run(output, feed_dict={x: 3}), 8)
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
- sess.run(output, feed_dict={x: 2})
-
- @test_util.run_deprecated_v1
- def testCase_multiple_matches_non_exclusive(self):
- x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
- conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
- (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
- (math_ops.equal(x, 2), lambda: constant_op.constant(6))]
- default = lambda: constant_op.constant(8)
- output = control_flow_ops.case(conditions, default, exclusive=False)
- with self.cached_session() as sess:
- self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
- self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
- self.assertEqual(sess.run(output, feed_dict={x: 3}), 8)
-
- @test_util.run_deprecated_v1
- def testCase_withoutDefault(self):
- x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
- conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
- (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
- (math_ops.equal(x, 3), lambda: constant_op.constant(6))]
- output = control_flow_ops.case(conditions, exclusive=True)
- with self.cached_session() as sess:
- self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
- self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
- self.assertEqual(sess.run(output, feed_dict={x: 3}), 6)
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
- sess.run(output, feed_dict={x: 4})
-
- @test_util.run_deprecated_v1
- def testCase_withoutDefault_oneCondition(self):
- x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
- conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2))]
- output = control_flow_ops.case(conditions, exclusive=True)
- with self.cached_session() as sess:
- self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
- sess.run(output, feed_dict={x: 4})
-
- @test_util.run_in_graph_and_eager_modes
- def testCase_dict(self):
- x = constant_op.constant(2)
- conditions = {
- math_ops.equal(x, 1): lambda: constant_op.constant(2),
- math_ops.equal(x, 2): lambda: constant_op.constant(4)
- }
- output = control_flow_ops.case(conditions, exclusive=True)
- self.assertEqual(4, self.evaluate(output))
-
-
-class WhileLoopTestCase(test_util.TensorFlowTestCase):
-
- @test_util.run_in_graph_and_eager_modes
- def testWhileLoopWithSingleVariable(self):
- i = constant_op.constant(0)
- c = lambda i: math_ops.less(i, 10)
- b = lambda i: math_ops.add(i, 1)
- r = control_flow_ops.while_loop(c, b, [i])
-
- self.assertEqual(self.evaluate(r), 10)
-
- @test_util.run_in_graph_and_eager_modes
- def testEagerWhileLoopWithSingleVariable_bodyReturnsTuple(self):
- i = constant_op.constant(0)
- c = lambda i: math_ops.less(i, 10)
- b = lambda i: (math_ops.add(i, 1),)
- r = control_flow_ops.while_loop(c, b, [i])
-
- # Expect a tuple since that is what the body returns.
- self.assertEqual(self.evaluate(r), (10,))
-
- @test_util.run_deprecated_v1
- def testWhileLoopSameReturnShape_False(self):
- i = constant_op.constant(0)
- c = lambda i, _: math_ops.less(i, 10)
-
- # Body returns a [tensor, []]
- b = lambda i, _: [math_ops.add(i, 1), []]
-
- # Should only return the tensor.
- r = control_flow_ops.while_loop(c, b, [i, []])
- self.assertEqual(self.evaluate(r), 10)
-
- def testWhileLoopSameReturnShape_True(self):
- i = constant_op.constant(0)
- c = lambda i, _: math_ops.less(i, 10)
-
- # Body returns a [tensor, []]
- b = lambda i, _: [math_ops.add(i, 1), []]
-
- # Should only return the original structure.
- r = control_flow_ops.while_loop(c, b, [i, []], return_same_structure=True)
- self.assertEqual(self.evaluate(r), [10, []])
-
-
-class AssertTest(test_util.TensorFlowTestCase):
-
- @test_util.run_deprecated_v1
- def testAssert(self):
- i = constant_op.constant(0)
- c = control_flow_ops.Assert(i < 10, [i, [10], [i + 1]])
- self.evaluate(c)
-
- i = constant_op.constant(10)
- c = control_flow_ops.Assert(i < 10, [i, [10], [i + 1]])
- with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(c)
-
- @test_util.run_in_graph_and_eager_modes
- def testAssertInFunction(self):
-
- @def_function.function
- def whiny(value):
- control_flow_ops.Assert(value, ["Raised false"])
- return constant_op.constant(5)
-
- with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(whiny(False))
-
- self.assertAllEqual(whiny(True), 5)
-
-if __name__ == "__main__":
- googletest.main()
diff --git a/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs b/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs
deleted file mode 100644
index c8928fa9..00000000
--- a/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs
+++ /dev/null
@@ -1,875 +0,0 @@
-using Microsoft.VisualStudio.TestTools.UnitTesting;
-using Newtonsoft.Json.Linq;
-using NumSharp;
-using System;
-using System.Collections;
-using System.Collections.Generic;
-using Tensorflow;
-using Tensorflow.UnitTest;
-using Tensorflow.Util;
-using static Tensorflow.Binding;
-
-namespace TensorFlowNET.UnitTest.nest_test
-{
- ///
- /// excerpt of tensorflow/python/framework/util/nest_test.py
- ///
- [TestClass]
- public class NestTest : GraphModeTestBase
- {
- [TestInitialize]
- public void TestInitialize()
- {
- tf.Graph().as_default();
- }
-
- //public class PointXY
- //{
- // public double x;
- // public double y;
- //}
-
- // if attr:
- // class BadAttr(object):
- // """Class that has a non-iterable __attrs_attrs__."""
- // __attrs_attrs__ = None
-
- // @attr.s
- // class SampleAttr(object):
- // field1 = attr.ib()
- // field2 = attr.ib()
-
- // @test_util.assert_no_new_pyobjects_executing_eagerly
- // def testAttrsFlattenAndPack(self) :
- // if attr is None:
- // self.skipTest("attr module is unavailable.")
-
- // field_values = [1, 2]
- // sample_attr = NestTest.SampleAttr(* field_values)
- // self.assertFalse(nest._is_attrs(field_values))
- // self.assertTrue(nest._is_attrs(sample_attr))
- // flat = nest.flatten(sample_attr)
- // self.assertEqual(field_values, flat)
- // restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
- // self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
- // self.assertEqual(restructured_from_flat, sample_attr)
-
- //# Check that flatten fails if attributes are not iterable
- // with self.assertRaisesRegexp(TypeError, "object is not iterable"):
- // flat = nest.flatten(NestTest.BadAttr())
- [Ignore]
- [TestMethod]
- public void testFlattenAndPack()
- {
- object structure = new object[] { new object[] { 3, 4 }, 5, new object[] { 6, 7, new object[] { 9, 10 }, 8 } };
- var flat = new List