| @@ -12,6 +12,51 @@ namespace Tensorflow | |||||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | ||||
| /// https://www.tensorflow.org/guide/graphs | /// https://www.tensorflow.org/guide/graphs | ||||
| /// </summary> | /// </summary> | ||||
| /* | |||||
| A TensorFlow computation, represented as a dataflow graph. | |||||
| A `Graph` contains a set of | |||||
| `tf.Operation` objects, | |||||
| which represent units of computation; and | |||||
| `tf.Tensor` objects, which represent | |||||
| the units of data that flow between operations. | |||||
| A default `Graph` is always registered, and accessible by calling | |||||
| `tf.get_default_graph`. | |||||
| To add an operation to the default graph, simply call one of the functions | |||||
| that defines a new `Operation`: | |||||
| ```python | |||||
| c = tf.constant(4.0) | |||||
| assert c.graph is tf.get_default_graph() | |||||
| ``` | |||||
| Another typical usage involves the | |||||
| `tf.Graph.as_default` | |||||
| context manager, which overrides the current default graph for the | |||||
| lifetime of the context: | |||||
| ```python | |||||
| g = tf.Graph() | |||||
| with g.as_default(): | |||||
| # Define operations and tensors in `g`. | |||||
| c = tf.constant(30.0) | |||||
| assert c.graph is g | |||||
| ``` | |||||
| Important note: This class *is not* thread-safe for graph construction. All | |||||
| operations should be created from a single thread, or external | |||||
| synchronization must be provided. Unless otherwise specified, all methods | |||||
| are not thread-safe. | |||||
| A `Graph` instance supports an arbitrary number of "collections" | |||||
| that are identified by name. For convenience when building a large | |||||
| graph, collections can store groups of related objects: for | |||||
| example, the `tf.Variable` uses a collection (named | |||||
| `tf.GraphKeys.GLOBAL_VARIABLES`) for | |||||
| all variables that are created during the construction of a graph. The caller | |||||
| may define additional collections by specifying a new name. | |||||
| */ | |||||
| public partial class Graph : IPython, IDisposable | public partial class Graph : IPython, IDisposable | ||||
| { | { | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
| @@ -92,13 +93,15 @@ namespace Tensorflow.Operations | |||||
| switch (original_result) | switch (original_result) | ||||
| { | { | ||||
| case Tensor result: | |||||
| return (original_result, _BuildCondTensor(new[] { result.op })); | |||||
| case Operation[] results: | case Operation[] results: | ||||
| return (original_result, _BuildCondTensor(results)); | return (original_result, _BuildCondTensor(results)); | ||||
| case Tensor tensor: | |||||
| return (original_result, tensor); | |||||
| case float[] fv: | case float[] fv: | ||||
| { | |||||
| var result = ops.convert_to_tensor(fv[0]); | var result = ops.convert_to_tensor(fv[0]); | ||||
| return (original_result, result ); | return (original_result, result ); | ||||
| } | |||||
| default: | default: | ||||
| return (original_result, null); | return (original_result, null); | ||||
| } | } | ||||
| @@ -114,7 +117,7 @@ namespace Tensorflow.Operations | |||||
| switch (original_result) | switch (original_result) | ||||
| { | { | ||||
| case Tensor[] results: | case Tensor[] results: | ||||
| return (original_result, results); | |||||
| return (original_result, new Tensor[] { _BuildCondTensor(results.Select(t=>t.op).ToArray())}); | |||||
| case Operation[] results: | case Operation[] results: | ||||
| return (original_result, new Tensor[] { _BuildCondTensor (results) }); | return (original_result, new Tensor[] { _BuildCondTensor (results) }); | ||||
| case float[] fv: | case float[] fv: | ||||
| @@ -27,9 +27,9 @@ namespace Tensorflow | |||||
| for (int i = 0; i < NumInputs; i++) | for (int i = 0; i < NumInputs; i++) | ||||
| { | { | ||||
| var tf_outpus = Input(i); | |||||
| var op = new Operation(tf_outpus.oper); | |||||
| retval[i] = op.outputs[tf_outpus.index]; | |||||
| var tf_outputs = Input(i); | |||||
| var op = new Operation(tf_outputs.oper); | |||||
| retval[i] = op.outputs[tf_outputs.index]; | |||||
| } | } | ||||
| _inputs = new InputList(retval); | _inputs = new InputList(retval); | ||||
| @@ -142,10 +142,29 @@ namespace Tensorflow | |||||
| return tpl.ToArray(); | return tpl.ToArray(); | ||||
| }); | }); | ||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Produces the content of `output_tensor` only after `dependencies`. | |||||
| /// | |||||
| /// In some cases, a user may want the output of an operation to be | |||||
| /// consumed externally only after some other dependencies have run | |||||
| /// first.This function ensures returns `output_tensor`, but only after all | |||||
| /// operations in `dependencies` have run.Note that this means that there is | |||||
| /// no guarantee that `output_tensor` will be evaluated after any `dependencies` | |||||
| /// have run. | |||||
| /// | |||||
| /// See also `tf.tuple` and `tf.group`. | |||||
| /// </summary> | |||||
| /// <param name="dependencies">Iterable of operations to run before this op finishes.</param> | |||||
| /// <param name="output_tensor">A `Tensor` or `IndexedSlices` that will be returned.</param> | |||||
| /// <param name="name">(Optional) A name for this operation.</param> | |||||
| /// <returns>Same as `output_tensor`.</returns> | |||||
| public static Tensor with_dependencies(Operation[] dependencies, Tensor output_tensor, string name = null) | public static Tensor with_dependencies(Operation[] dependencies, Tensor output_tensor, string name = null) | ||||
| { | { | ||||
| //TODO: missing original code | |||||
| //if context.executing_eagerly(): | |||||
| // return output_tensor | |||||
| var values = new List<object>(); | var values = new List<object>(); | ||||
| values.AddRange(dependencies); | values.AddRange(dependencies); | ||||
| values.Add(output_tensor); | values.Add(output_tensor); | ||||
| @@ -153,12 +172,15 @@ namespace Tensorflow | |||||
| return with(ops.name_scope(name, "control_dependency", values), scope => | return with(ops.name_scope(name, "control_dependency", values), scope => | ||||
| { | { | ||||
| name = scope; | name = scope; | ||||
| return with(ops.control_dependencies(dependencies), ctl => | |||||
| // TODO: missing original code | |||||
| //with ops.colocate_with(output_tensor): | |||||
| { | { | ||||
| output_tensor = ops.convert_to_tensor_or_composite(output_tensor); | |||||
| return _Identity(output_tensor, name: name); | |||||
| }); | |||||
| return with(ops.control_dependencies(dependencies), ctl => | |||||
| { | |||||
| output_tensor = ops.convert_to_tensor_or_composite(output_tensor); | |||||
| return _Identity(output_tensor, name: name); | |||||
| }); | |||||
| } | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -393,8 +415,27 @@ namespace Tensorflow | |||||
| return tensors_or_flows; | return tensors_or_flows; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Returns the value of an available element of `inputs`. | |||||
| /// | |||||
| /// This op tests each of the tensors in `inputs` in turn to determine if any of | |||||
| /// them is available.If it finds an available tensor, it returns it and its | |||||
| /// index in `inputs`. | |||||
| /// | |||||
| /// It is an error if more than one tensor in `inputs` is available.If no tensor | |||||
| /// in `inputs` is available, the returned tensor and index are not set. | |||||
| /// | |||||
| /// This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of | |||||
| /// `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices | |||||
| /// before merging. | |||||
| /// </summary> | |||||
| /// <param name="inputs">inputs: The input tensors, at most one of which is available.</param> | |||||
| /// <param name="name">A name for this operation (optional).</param> | |||||
| /// <returns></returns> | |||||
| public static Tensor merge(Tensor[] inputs, string name = null) | public static Tensor merge(Tensor[] inputs, string name = null) | ||||
| { | { | ||||
| if (inputs.Any(x => x == null)) | |||||
| throw new ValueError($"At least one of the merge inputs is null: {inputs}"); | |||||
| return with(ops.name_scope(name, "Merge", inputs), scope => | return with(ops.name_scope(name, "Merge", inputs), scope => | ||||
| { | { | ||||
| name = scope; | name = scope; | ||||
| @@ -49,19 +49,59 @@ namespace Tensorflow | |||||
| return get_default_graph().get_collection_ref(key); | return get_default_graph().get_collection_ref(key); | ||||
| } | } | ||||
| private static Graph default_graph; | |||||
| private static Graph default_graph; | |||||
| /// <summary> | |||||
| /// Returns the default graph for the current thread. | |||||
| /// | |||||
| /// The returned graph will be the innermost graph on which a | |||||
| /// `Graph.as_default()` context has been entered, or a global default | |||||
| /// graph if none has been explicitly created. | |||||
| /// | |||||
| /// NOTE: The default graph is a property of the current thread.If you | |||||
| /// create a new thread, and wish to use the default graph in that | |||||
| /// thread, you must explicitly add a `with g.as_default():` in that | |||||
| /// thread's function. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public static Graph get_default_graph() | public static Graph get_default_graph() | ||||
| { | { | ||||
| //TODO: original source indicates there should be a _default_graph_stack! | |||||
| //return _default_graph_stack.get_default() | |||||
| if (default_graph == null) | if (default_graph == null) | ||||
| default_graph = tf.Graph(); | default_graph = tf.Graph(); | ||||
| return default_graph; | return default_graph; | ||||
| } | } | ||||
| public static Graph set_default_graph(Graph graph) | public static Graph set_default_graph(Graph graph) | ||||
| { | { | ||||
| //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! | |||||
| default_graph = graph; | default_graph = graph; | ||||
| return default_graph; | return default_graph; | ||||
| } | |||||
| /// <summary> | |||||
| /// Clears the default graph stack and resets the global default graph. | |||||
| /// | |||||
| /// NOTE: The default graph is a property of the current thread.This | |||||
| /// function applies only to the current thread.Calling this function while | |||||
| /// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined | |||||
| /// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects | |||||
| /// after calling this function will result in undefined behavior. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public static void reset_default_graph() | |||||
| { | |||||
| //TODO: original source indicates there should be a _default_graph_stack! | |||||
| //if (!_default_graph_stack.is_cleared()) | |||||
| // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||||
| // "nested graphs. If you need a cleared graph, " + | |||||
| // "exit the nesting and create a new graph."); | |||||
| //_default_graph_stack.reset(); | |||||
| if (default_graph!=null) | |||||
| default_graph.Dispose(); | |||||
| default_graph = tf.Graph(); | |||||
| } | } | ||||
| public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null) | public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null) | ||||
| { | { | ||||
| foreach(var op_input in op_input_list) | foreach(var op_input in op_input_list) | ||||
| @@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest | |||||
| /// </summary> | /// </summary> | ||||
| public class PythonTest : Python | public class PythonTest : Python | ||||
| { | { | ||||
| public void assertItemsEqual(ICollection expected, ICollection given) | |||||
| public void assertItemsEqual(ICollection given, ICollection expected) | |||||
| { | { | ||||
| Assert.IsNotNull(expected); | Assert.IsNotNull(expected); | ||||
| Assert.IsNotNull(given); | Assert.IsNotNull(given); | ||||
| @@ -6,7 +6,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| namespace TensorFlowNET.UnitTest | |||||
| namespace TensorFlowNET.UnitTest.ops_test | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// excerpt of tensorflow/python/framework/ops_test.py | /// excerpt of tensorflow/python/framework/ops_test.py | ||||
| @@ -157,8 +157,8 @@ namespace TensorFlowNET.UnitTest | |||||
| }); | }); | ||||
| }); | }); | ||||
| }); | }); | ||||
| assertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs); | |||||
| assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs); | |||||
| assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); | |||||
| assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -200,6 +200,7 @@ namespace TensorFlowNET.UnitTest | |||||
| b_none2 = constant_op.constant(12.0); | b_none2 = constant_op.constant(12.0); | ||||
| }); | }); | ||||
| }); | }); | ||||
| // Note assertItemsEqual(given, expected), expected and given parameters should be swapped below | |||||
| assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs); | assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs); | ||||
| assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs); | assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs); | ||||
| assertItemsEqual(new object[0], b_none.op.control_inputs); | assertItemsEqual(new object[0], b_none.op.control_inputs); | ||||
| @@ -256,6 +257,7 @@ namespace TensorFlowNET.UnitTest | |||||
| }); | }); | ||||
| }); | }); | ||||
| // Note assertItemsEqual(given, expected), expected and given parameters should be swapped below | |||||
| assertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs); | assertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs); | ||||
| assertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs); | assertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs); | ||||
| assertItemsEqual(new[] { a_1.op, a_2.op}, b_3.op.control_inputs); | assertItemsEqual(new[] { a_1.op, a_2.op}, b_3.op.control_inputs); | ||||
| @@ -1,10 +1,12 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Operations; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| namespace TensorFlowNET.UnitTest.ops_test | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// excerpt of tensorflow/python/framework/ops_test.py | /// excerpt of tensorflow/python/framework/ops_test.py | ||||
| @@ -19,21 +21,21 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestClass] | [TestClass] | ||||
| public class CreateOpFromTfOperationTest : PythonTest | public class CreateOpFromTfOperationTest : PythonTest | ||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void TestShape() | public void TestShape() | ||||
| { | { | ||||
| var graph = tf.Graph().as_default(); | var graph = tf.Graph().as_default(); | ||||
| with<Graph>(graph, g => | with<Graph>(graph, g => | ||||
| { | { | ||||
| var x = constant_op.constant(new [,] { {1, 2, 3}, {4, 5, 6}}); | |||||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||||
| var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); | |||||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); | |||||
| var op = g._create_op_from_tf_operation(c_op); | var op = g._create_op_from_tf_operation(c_op); | ||||
| Assert.AreEqual("myop", op.name); | Assert.AreEqual("myop", op.name); | ||||
| Assert.AreEqual("Identity", op.type); | Assert.AreEqual("Identity", op.type); | ||||
| Assert.AreEqual(1, len(op.outputs)); | Assert.AreEqual(1, len(op.outputs)); | ||||
| assertItemsEqual(new []{2, 3}, op.outputs[0].shape); | |||||
| assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -47,7 +49,7 @@ namespace TensorFlowNET.UnitTest | |||||
| //var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]); | //var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]); | ||||
| //var op = g._create_op_from_tf_operation(c_op); | //var op = g._create_op_from_tf_operation(c_op); | ||||
| //var op2 = g._create_op_from_tf_operation(c_op2); | //var op2 = g._create_op_from_tf_operation(c_op2); | ||||
| var op = constant_op.constant(0, name:"myop").op; | |||||
| var op = constant_op.constant(0, name: "myop").op; | |||||
| var op2 = constant_op.constant(0, name: "myop_1").op; | var op2 = constant_op.constant(0, name: "myop_1").op; | ||||
| // Create ops with same names as op1 and op2. We expect the new names to be | // Create ops with same names as op1 and op2. We expect the new names to be | ||||
| @@ -62,7 +64,7 @@ namespace TensorFlowNET.UnitTest | |||||
| }); | }); | ||||
| } | } | ||||
| [Ignore("Something is not right, Switch gets not inserted correctly?")] | |||||
| [Ignore("Switch op gets not inserted correctly in the graph")] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void TestCond() | public void TestCond() | ||||
| { | { | ||||
| @@ -91,8 +93,7 @@ namespace TensorFlowNET.UnitTest | |||||
| self.assertEqual(op_input.inputs[0], x); | self.assertEqual(op_input.inputs[0], x); | ||||
| self.assertEqual(op.graph, g); | self.assertEqual(op.graph, g); | ||||
| self.assertIsNotNone(op._get_control_flow_context()); | self.assertIsNotNone(op._get_control_flow_context()); | ||||
| // TODO: op._get_control_flow_context().name not implemented | |||||
| //self.assertEqual(op._get_control_flow_context().name, "cond/cond_text"); | |||||
| self.assertEqual((op._get_control_flow_context() as ControlFlowContext).name, "cond/cond_text"); | |||||
| }); | }); | ||||
| /* | /* | ||||
| @test_util.run_v1_only("b/120545219") | @test_util.run_v1_only("b/120545219") | ||||
| @@ -126,7 +127,39 @@ namespace TensorFlowNET.UnitTest | |||||
| # pylint: enable=protected-access | # pylint: enable=protected-access | ||||
| */ | */ | ||||
| } | } | ||||
| /* | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void TestWhileLoop() | |||||
| { | |||||
| var graph = tf.Graph().as_default(); | |||||
| Operation x=null; | |||||
| with<Graph>(graph, g => | |||||
| { | |||||
| x = constant_op.constant(42); | |||||
| var body = new Func<int, int>(i => | |||||
| { | |||||
| ops._create_c_op(ops.get_default_graph(), ops._NodeDef("Identity", "myloop/myop"), new[] {x}, | |||||
| new Operation[0]); | |||||
| var new_ops = g._add_new_tf_operations(); | |||||
| self.assertEqual(len(new_ops), 1); | |||||
| return i; | |||||
| }); | |||||
| // TODO: port control_flow_ops.while_loop | |||||
| //control_flow_ops.while_loop( i => i < 10, body, new int[]{0}, name = "myloop"); | |||||
| }); | |||||
| var op = graph.get_operation_by_name("myloop/myop"); | |||||
| self.assertIsNotNone(op); | |||||
| self.assertEqual(op.name, "myloop/myop"); | |||||
| self.assertEqual(op.type, "Identity"); | |||||
| self.assertEqual(op.outputs.Length, 0); | |||||
| var op_input = op.inputs[0].op; | |||||
| self.assertEqual(op_input.type, "Enter"); | |||||
| self.assertItemsEqual(op_input.inputs.OfType<Operation>().ToArray(), new[] {x}); | |||||
| self.assertEqual(op.graph, graph); | |||||
| self.assertIsNotNone(op._get_control_flow_context()); | |||||
| self.assertEqual(((ControlFlowContext)op._get_control_flow_context()).name, "myloop/while_context"); | |||||
| /* | |||||
| @test_util.run_v1_only("b/120545219") | @test_util.run_v1_only("b/120545219") | ||||
| def testWhileLoop(self): | def testWhileLoop(self): | ||||
| g = ops.Graph() | g = ops.Graph() | ||||
| @@ -156,8 +189,15 @@ namespace TensorFlowNET.UnitTest | |||||
| self.assertEqual(op._get_control_flow_context().name, | self.assertEqual(op._get_control_flow_context().name, | ||||
| "myloop/while_context") | "myloop/while_context") | ||||
| # pylint: enable=protected-access | # pylint: enable=protected-access | ||||
| */ | |||||
| } | |||||
| @test_util.run_v1_only("b/120545219") | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void TestWhileLoopWithInternalControlDep() | |||||
| { | |||||
| /* | |||||
| @test_util.run_v1_only("b/120545219") | |||||
| def testWhileLoopWithInternalControlDep(self): | def testWhileLoopWithInternalControlDep(self): | ||||
| g = ops.Graph() | g = ops.Graph() | ||||
| with g.as_default(): | with g.as_default(): | ||||
| @@ -180,7 +220,14 @@ namespace TensorFlowNET.UnitTest | |||||
| self.assertIsNotNone(c) | self.assertIsNotNone(c) | ||||
| # Internal control dep is preserved | # Internal control dep is preserved | ||||
| self.assertEqual(op.control_inputs, [c]) | self.assertEqual(op.control_inputs, [c]) | ||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void TestWhileLoopWithExternalControlDep() | |||||
| { | |||||
| /* | |||||
| @test_util.run_v1_only("b/120545219") | @test_util.run_v1_only("b/120545219") | ||||
| def testWhileLoopWithExternalControlDep(self): | def testWhileLoopWithExternalControlDep(self): | ||||
| g = ops.Graph() | g = ops.Graph() | ||||
| @@ -203,8 +250,8 @@ namespace TensorFlowNET.UnitTest | |||||
| # External control dep is removed and replaced with internal control dep | # External control dep is removed and replaced with internal control dep | ||||
| self.assertNotEqual(op.control_inputs[0], c.op) | self.assertNotEqual(op.control_inputs[0], c.op) | ||||
| self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) | self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) | ||||
| */ | |||||
| } | |||||
| */ | |||||
| } | |||||
| } | } | ||||
| } | |||||
| @@ -0,0 +1,200 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Operations; | |||||
| namespace TensorFlowNET.UnitTest.ops_test | |||||
| { | |||||
| /// <summary> | |||||
| /// excerpt of tensorflow/python/framework/ops_test.py | |||||
| /// </summary> | |||||
| [TestClass] | |||||
| public class GraphTest : PythonTest | |||||
| { | |||||
| [TestInitialize] | |||||
| public void SetUp() | |||||
| { | |||||
| ops.reset_default_graph(); | |||||
| } | |||||
| [TestCleanup] | |||||
| public void TearDown() | |||||
| { | |||||
| ops.reset_default_graph(); | |||||
| } | |||||
| private void _AssertDefault(Graph expected) { | |||||
| Assert.AreSame(ops.get_default_graph(), expected); | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testResetDefaultGraphNesting() | |||||
| { | |||||
| /* | |||||
| def testResetDefaultGraphNesting(self): | |||||
| g0 = ops.Graph() | |||||
| with self.assertRaises(AssertionError): | |||||
| with g0.as_default(): | |||||
| ops.reset_default_graph() | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testGraphContextManagerCancelsEager() | |||||
| { | |||||
| /* | |||||
| def testGraphContextManagerCancelsEager(self): | |||||
| with context.eager_mode(): | |||||
| with ops.Graph().as_default(): | |||||
| self.assertFalse(context.executing_eagerly()) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testGraphContextManager() | |||||
| { | |||||
| /* | |||||
| def testGraphContextManager(self): | |||||
| g0 = ops.Graph() | |||||
| with g0.as_default() as g1: | |||||
| self.assertIs(g0, g1) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testDefaultGraph() | |||||
| { | |||||
| /* | |||||
| def testDefaultGraph(self): | |||||
| orig = ops.get_default_graph() | |||||
| self._AssertDefault(orig) | |||||
| g0 = ops.Graph() | |||||
| self._AssertDefault(orig) | |||||
| context_manager_0 = g0.as_default() | |||||
| self._AssertDefault(orig) | |||||
| with context_manager_0 as g0: | |||||
| self._AssertDefault(g0) | |||||
| with ops.Graph().as_default() as g1: | |||||
| self._AssertDefault(g1) | |||||
| self._AssertDefault(g0) | |||||
| self._AssertDefault(orig) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testPreventFeeding() | |||||
| { | |||||
| /* | |||||
| def testPreventFeeding(self): | |||||
| g = ops.Graph() | |||||
| a = constant_op.constant(2.0) | |||||
| self.assertTrue(g.is_feedable(a)) | |||||
| g.prevent_feeding(a) | |||||
| self.assertFalse(g.is_feedable(a)) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testAsGraphElementConversions() | |||||
| { | |||||
| /* | |||||
| def testAsGraphElementConversions(self): | |||||
| class ConvertibleObj(object): | |||||
| def _as_graph_element(self): | |||||
| return "FloatOutput:0" | |||||
| class NonConvertibleObj(object): | |||||
| pass | |||||
| g = ops.Graph() | |||||
| a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||||
| self.assertEqual(a, g.as_graph_element(ConvertibleObj())) | |||||
| with self.assertRaises(TypeError): | |||||
| g.as_graph_element(NonConvertibleObj()) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testGarbageCollected() | |||||
| { | |||||
| /* | |||||
| # Regression test against creating custom __del__ functions in classes | |||||
| # involved in cyclic references, e.g. Graph and Operation. (Python won't gc | |||||
| # cycles that require calling a __del__ method, because the __del__ method can | |||||
| # theoretically increase the object's refcount to "save" it from gc, and any | |||||
| # already-deleted objects in the cycle would have be to restored.) | |||||
| def testGarbageCollected(self): | |||||
| # Create a graph we can delete and a weak reference to monitor if it's gc'd | |||||
| g = ops.Graph() | |||||
| g_ref = weakref.ref(g) | |||||
| # Create some ops | |||||
| with g.as_default(): | |||||
| a = constant_op.constant(2.0) | |||||
| b = constant_op.constant(3.0) | |||||
| c = math_ops.add(a, b) | |||||
| # Create a session we can delete | |||||
| with session.Session(graph=g) as sess: | |||||
| self.evaluate(c) | |||||
| # Delete all references and trigger gc | |||||
| del g | |||||
| del a | |||||
| del b | |||||
| del c | |||||
| del sess | |||||
| gc.collect() | |||||
| self.assertIsNone(g_ref()) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testRunnableAfterInvalidShape() | |||||
| { | |||||
| /* | |||||
| def testRunnableAfterInvalidShape(self): | |||||
| with ops.Graph().as_default(): | |||||
| with self.assertRaises(ValueError): | |||||
| math_ops.add([1, 2], [1, 2, 3]) | |||||
| a = constant_op.constant(1) | |||||
| with session.Session() as sess: | |||||
| self.evaluate(a) | |||||
| */ | |||||
| } | |||||
| [Ignore("Todo: Port")] | |||||
| [TestMethod] | |||||
| public void testRunnableAfterInvalidShapeWithKernelLabelMap() | |||||
| { | |||||
| /* | |||||
| def testRunnableAfterInvalidShapeWithKernelLabelMap(self): | |||||
| g = ops.Graph() | |||||
| with g.as_default(): | |||||
| with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): | |||||
| with self.assertRaises(ValueError): | |||||
| test_ops.kernel_label_required(1) | |||||
| a = constant_op.constant(1) | |||||
| with session.Session() as sess: | |||||
| self.evaluate(a) | |||||
| */ | |||||
| } | |||||
| } | |||||
| } | |||||