| @@ -27,7 +27,7 @@ namespace Tensorflow | |||||
| // 100K gradient 44M. | // 100K gradient 44M. | ||||
| mm.Execute(10, 10 * batchSize, cases.Gradient); | mm.Execute(10, 10 * batchSize, cases.Gradient); | ||||
| // 120M | |||||
| // 95M | |||||
| Console.WriteLine("Finished."); | Console.WriteLine("Finished."); | ||||
| Console.ReadLine(); | Console.ReadLine(); | ||||
| } | } | ||||
| @@ -0,0 +1,30 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using NumSharp; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class tensorflow | |||||
| { | |||||
| public CompatApi compat { get; } = new CompatApi(); | |||||
| public class CompatApi | |||||
| { | |||||
| public CompatV1Api v1 { get; } = new CompatV1Api(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,30 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class CompatV1Api | |||||
| { | |||||
| public void disable_eager_execution() | |||||
| { | |||||
| tf.context.default_execution_mode = Context.GRAPH_MODE; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -259,7 +259,8 @@ namespace Tensorflow | |||||
| public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, | public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, | ||||
| TF_DataType[] input_types = null, string name = null, | TF_DataType[] input_types = null, string name = null, | ||||
| Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | |||||
| Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, | |||||
| bool compute_device = true) | |||||
| { | { | ||||
| if (inputs == null) | if (inputs == null) | ||||
| inputs = new Tensor[0]; | inputs = new Tensor[0]; | ||||
| @@ -270,7 +271,7 @@ namespace Tensorflow | |||||
| // If a names ends with a '/' it is a "name scope" and we use it as-is, | // If a names ends with a '/' it is a "name scope" and we use it as-is, | ||||
| // after removing the trailing '/'. | // after removing the trailing '/'. | ||||
| name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); | name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); | ||||
| var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | |||||
| var node_def = ops._NodeDef(op_type, name, attrs: attrs); | |||||
| var input_ops = inputs.Select(x => x.op).ToArray(); | var input_ops = inputs.Select(x => x.op).ToArray(); | ||||
| var control_inputs = _control_dependencies_for_inputs(input_ops); | var control_inputs = _control_dependencies_for_inputs(input_ops); | ||||
| @@ -284,7 +285,7 @@ namespace Tensorflow | |||||
| original_op: null, | original_op: null, | ||||
| op_def: op_def); | op_def: op_def); | ||||
| _create_op_helper(op, true); | |||||
| _create_op_helper(op, compute_device); | |||||
| /*Console.Write($"create_op: {op_type} '{node_def.Name}'"); | /*Console.Write($"create_op: {op_type} '{node_def.Name}'"); | ||||
| Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); | Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); | ||||
| @@ -40,8 +40,8 @@ namespace Tensorflow | |||||
| public void _add_control_input(Operation op) | public void _add_control_input(Operation op) | ||||
| { | { | ||||
| //c_api.TF_AddControlInput(_operDesc, op); | |||||
| c_api.AddControlInput(graph, _handle, op); | |||||
| c_api.TF_AddControlInput(OpDesc, op); | |||||
| //c_api.AddControlInput(graph, _handle, op); | |||||
| } | } | ||||
| public void _add_control_inputs(Operation[] ops) | public void _add_control_inputs(Operation[] ops) | ||||
| @@ -64,7 +64,7 @@ namespace Tensorflow | |||||
| public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | ||||
| bool _is_stateful; | bool _is_stateful; | ||||
| public OperationDescription OpDesc { get; set; } | |||||
| public NodeDef node_def | public NodeDef node_def | ||||
| { | { | ||||
| @@ -170,7 +170,7 @@ namespace Tensorflow | |||||
| op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
| var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | ||||
| _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
| (_handle, OpDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
| _is_stateful = op_def.IsStateful; | _is_stateful = op_def.IsStateful; | ||||
| // Initialize self._outputs. | // Initialize self._outputs. | ||||
| @@ -187,9 +187,6 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe ulong TF_StringEncode(byte* src, ulong src_len, sbyte* dst, ulong dst_len, SafeStatusHandle status); | public static extern unsafe ulong TF_StringEncode(byte* src, ulong src_len, sbyte* dst, ulong dst_len, SafeStatusHandle status); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern unsafe ulong TF_StringEncode(IntPtr src, ulong src_len, IntPtr dst, ulong dst_len, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Decode a string encoded using TF_StringEncode. | /// Decode a string encoded using TF_StringEncode. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -199,9 +196,6 @@ namespace Tensorflow | |||||
| /// <param name="dst_len">size_t*</param> | /// <param name="dst_len">size_t*</param> | ||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status); | public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status); | ||||
| @@ -155,7 +155,7 @@ namespace Tensorflow | |||||
| /// </param> | /// </param> | ||||
| /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | ||||
| /// <returns>A wrapped TF_Operation*.</returns> | /// <returns>A wrapped TF_Operation*.</returns> | ||||
| public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | |||||
| public static (IntPtr, OperationDescription) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | |||||
| { | { | ||||
| lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
| { | { | ||||
| @@ -198,7 +198,7 @@ namespace Tensorflow | |||||
| status.Check(true); | status.Check(true); | ||||
| return c_op; | |||||
| return (c_op, op_desc); | |||||
| } | } | ||||
| } | } | ||||
| @@ -207,7 +207,7 @@ namespace Tensorflow | |||||
| return graph.GetOpDef(type); | return graph.GetOpDef(type); | ||||
| } | } | ||||
| public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null) | |||||
| public static NodeDef _NodeDef(string op_type, string name, Dictionary<string, AttrValue> attrs = null) | |||||
| { | { | ||||
| var node_def = new NodeDef(); | var node_def = new NodeDef(); | ||||
| node_def.Op = op_type; | node_def.Op = op_type; | ||||
| @@ -4,13 +4,13 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.UnitTest; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.Basics | namespace TensorFlowNET.UnitTest.Basics | ||||
| { | { | ||||
| [Ignore] | |||||
| [TestClass] | [TestClass] | ||||
| public class QueueTest | |||||
| public class QueueTest : GraphModeTestBase | |||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void PaddingFIFOQueue() | public void PaddingFIFOQueue() | ||||
| @@ -10,7 +10,6 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| [TestClass] | [TestClass] | ||||
| public class VariableTest | public class VariableTest | ||||
| { | { | ||||
| [Ignore] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void NewVariable() | public void NewVariable() | ||||
| { | { | ||||
| @@ -34,7 +33,6 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| Assert.AreEqual(4, (int)y.numpy()); | Assert.AreEqual(4, (int)y.numpy()); | ||||
| } | } | ||||
| [Ignore] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Assign1() | public void Assign1() | ||||
| { | { | ||||
| @@ -0,0 +1,24 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using TensorFlowNET.UnitTest; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.UnitTest | |||||
| { | |||||
| public class GraphModeTestBase : PythonTest | |||||
| { | |||||
| [TestInitialize] | |||||
| public void TestInit() | |||||
| { | |||||
| tf.compat.v1.disable_eager_execution(); | |||||
| } | |||||
| [TestCleanup] | |||||
| public void TestClean() | |||||
| { | |||||
| tf.enable_eager_execution(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,16 +1,16 @@ | |||||
| using System; | using System; | ||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.UnitTest; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.Basics | namespace TensorFlowNET.UnitTest.Basics | ||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class NameScopeTest | |||||
| public class NameScopeTest : GraphModeTestBase | |||||
| { | { | ||||
| string name = ""; | string name = ""; | ||||
| [Ignore] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void NestedNameScope() | public void NestedNameScope() | ||||
| { | { | ||||
| @@ -7,12 +7,12 @@ using Tensorflow; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using Buffer = Tensorflow.Buffer; | using Buffer = Tensorflow.Buffer; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.UnitTest; | |||||
| namespace TensorFlowNET.UnitTest.Basics | namespace TensorFlowNET.UnitTest.Basics | ||||
| { | { | ||||
| [Ignore] | |||||
| [TestClass] | [TestClass] | ||||
| public class OperationsTest | |||||
| public class OperationsTest : GraphModeTestBase | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Port from tensorflow\c\c_api_test.cc | /// Port from tensorflow\c\c_api_test.cc | ||||
| @@ -726,6 +726,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| #endregion | #endregion | ||||
| } | } | ||||
| [Ignore] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void divOpTests() | public void divOpTests() | ||||
| { | { | ||||
| @@ -3,6 +3,7 @@ using NumSharp; | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -160,23 +161,6 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
| Assert.AreEqual(6.0, (double)c); | Assert.AreEqual(6.0, (double)c); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void StringEncode() | |||||
| { | |||||
| string str = "Hello, TensorFlow.NET!"; | |||||
| var handle = Marshal.StringToHGlobalAnsi(str); | |||||
| var dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); | |||||
| Assert.AreEqual(dst_len, (ulong)23); | |||||
| IntPtr dst = Marshal.AllocHGlobal((int)dst_len); | |||||
| var encoded_len = c_api.TF_StringEncode(handle, (ulong)str.Length, dst, dst_len, status.Handle); | |||||
| Assert.AreEqual((ulong)23, encoded_len); | |||||
| Assert.AreEqual(status.Code, TF_Code.TF_OK); | |||||
| string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); | |||||
| Assert.AreEqual(encoded_str, str); | |||||
| Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); | |||||
| // c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status.Handle); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Reshape() | public void Reshape() | ||||
| { | { | ||||
| @@ -1,5 +1,6 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.UnitTest; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.control_flow_ops_test | namespace TensorFlowNET.UnitTest.control_flow_ops_test | ||||
| @@ -7,10 +8,10 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
| /// <summary> | /// <summary> | ||||
| /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py | /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py | ||||
| /// </summary> | /// </summary> | ||||
| [Ignore] | |||||
| [TestClass] | [TestClass] | ||||
| public class CondTestCases : PythonTest | |||||
| public class CondTestCases : GraphModeTestBase | |||||
| { | { | ||||
| [Ignore("Dependent on UpdateEdge")] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void testCondTrue_ConstOnly() | public void testCondTrue_ConstOnly() | ||||
| { | { | ||||
| @@ -49,6 +50,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
| } | } | ||||
| } | } | ||||
| [Ignore("Dependent on UpdateEdge")] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void testCondTrue() | public void testCondTrue() | ||||
| { | { | ||||
| @@ -65,6 +67,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
| assertEquals(result, 34); | assertEquals(result, 34); | ||||
| } | } | ||||
| [Ignore("Dependent on UpdateEdge")] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void testCondFalse() | public void testCondFalse() | ||||
| { | { | ||||
| @@ -1,14 +1,14 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.UnitTest; | |||||
| namespace TensorFlowNET.UnitTest.control_flow_ops_test | namespace TensorFlowNET.UnitTest.control_flow_ops_test | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py | /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py | ||||
| /// </summary> | /// </summary> | ||||
| [Ignore] | |||||
| [TestClass] | [TestClass] | ||||
| public class ShapeTestCase : PythonTest | |||||
| public class ShapeTestCase : GraphModeTestBase | |||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -1,24 +1,24 @@ | |||||
| using System; | using System; | ||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.UnitTest; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.control_flow_ops_test | namespace TensorFlowNET.UnitTest.control_flow_ops_test | ||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class WhileContextTestCase : PythonTest | |||||
| public class WhileContextTestCase : GraphModeTestBase | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// https://www.tensorflow.org/api_docs/python/tf/while_loop | /// https://www.tensorflow.org/api_docs/python/tf/while_loop | ||||
| /// </summary> | /// </summary> | ||||
| [Ignore] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void SimpleWhileLoop() | public void SimpleWhileLoop() | ||||
| { | { | ||||
| var i = constant_op.constant(0, name: "i"); | var i = constant_op.constant(0, name: "i"); | ||||
| var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); | var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); | ||||
| var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); | var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); | ||||
| //var r = control_flow_ops.while_loop(c, b, i); | |||||
| // var r = control_flow_ops.while_loop(c, b, i); | |||||
| } | } | ||||
| private void _testWhileContextHelper(int maximum_iterations) | private void _testWhileContextHelper(int maximum_iterations) | ||||
| @@ -2,15 +2,14 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using NumSharp; | using NumSharp; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.UnitTest; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.img_test | namespace TensorFlowNET.UnitTest.img_test | ||||
| { | { | ||||
| [Ignore] | |||||
| [TestClass] | [TestClass] | ||||
| public class TestCrop | |||||
| public class TestCrop : GraphModeTestBase | |||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void TestCropAndResize() | public void TestCropAndResize() | ||||
| { | { | ||||
| @@ -3,13 +3,13 @@ using FluentAssertions; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using NumSharp; | using NumSharp; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.UnitTest; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.layers_test | namespace TensorFlowNET.UnitTest.layers_test | ||||
| { | { | ||||
| [Ignore] | |||||
| [TestClass] | [TestClass] | ||||
| public class flatten | |||||
| public class flatten : GraphModeTestBase | |||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void Case1() | public void Case1() | ||||
| @@ -3,6 +3,7 @@ using System.Linq; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.UnitTest; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.ops_test | namespace TensorFlowNET.UnitTest.ops_test | ||||
| @@ -10,9 +11,8 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
| /// <summary> | /// <summary> | ||||
| /// excerpt of tensorflow/python/framework/ops_test.py | /// excerpt of tensorflow/python/framework/ops_test.py | ||||
| /// </summary> | /// </summary> | ||||
| [Ignore] | |||||
| [TestClass] | [TestClass] | ||||
| public class ControlDependenciesTest : PythonTest | |||||
| public class ControlDependenciesTest : GraphModeTestBase | |||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void TestBasic() | public void TestBasic() | ||||
| @@ -35,72 +35,6 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
| Assert.AreEqual(0, e.op.control_inputs.Length); | Assert.AreEqual(0, e.op.control_inputs.Length); | ||||
| } | } | ||||
| [Ignore("Future is not supported yet")] | |||||
| [TestMethod] | |||||
| public void TestEager() | |||||
| { | |||||
| Tensor a = null, c = null; | |||||
| object b = null; | |||||
| var calls = 0; | |||||
| Func<Tensor> future = () => | |||||
| { | |||||
| calls += 1; | |||||
| return constant_op.constant(2.0); | |||||
| }; | |||||
| using (var opts = new ContextOptions()) | |||||
| using (var status = new Status()) | |||||
| using (var context = new Context(opts, status)) | |||||
| { | |||||
| if (context.executing_eagerly()) | |||||
| { | |||||
| // TODO: make this compile (see original Python code below) | |||||
| a = constant_op.constant(1.0); | |||||
| b = future; // <--- {henon} obviously, this doesn't compile, looks like control_dependencies needs to be able to take callables as well. | |||||
| tf_with(ops.control_dependencies(new object[] { a, b }), ctrl => | |||||
| { | |||||
| return c = constant_op.constant(3.0); | |||||
| }); | |||||
| Assert.AreEqual(calls, 1); | |||||
| } | |||||
| else | |||||
| { | |||||
| var g = tf.Graph().as_default(); | |||||
| a = constant_op.constant(1.0); | |||||
| var b1 = future(); | |||||
| tf_with(g.control_dependencies(new[] { a, b }), ctrl => | |||||
| { | |||||
| c = constant_op.constant(3.0); | |||||
| }); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op, b1.op })); | |||||
| Assert.AreEqual(1, calls); | |||||
| } | |||||
| } | |||||
| /* | |||||
| def testEager(self): | |||||
| def future(): | |||||
| future.calls += 1 | |||||
| return constant_op.constant(2.0) | |||||
| future.calls = 0 | |||||
| if context.executing_eagerly(): | |||||
| a = constant_op.constant(1.0) | |||||
| b = future | |||||
| with ops.control_dependencies([a, b]): | |||||
| c = constant_op.constant(3.0) | |||||
| self.assertEqual(future.calls, 1) | |||||
| else: | |||||
| g = ops.Graph() | |||||
| with g.as_default(): | |||||
| a = constant_op.constant(1.0) | |||||
| b = future() | |||||
| with g.control_dependencies([a, b]): | |||||
| c = constant_op.constant(3.0) | |||||
| self.assertEqual(c.op.control_inputs, [a.op, b.op]) | |||||
| self.assertEqual(future.calls, 1) | |||||
| */ | |||||
| } | |||||
| [Ignore("How to port the ConvertibleObj?")] | [Ignore("How to port the ConvertibleObj?")] | ||||
| [TestMethod] | [TestMethod] | ||||
| public void TestBasicWithConversion() | public void TestBasicWithConversion() | ||||
| @@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
| using (var g = tf.Graph().as_default()) | using (var g = tf.Graph().as_default()) | ||||
| { | { | ||||
| var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); | var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); | ||||
| var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||||
| var op = g._create_op_from_tf_operation(c_op); | var op = g._create_op_from_tf_operation(c_op); | ||||
| Assert.AreEqual("myop", op.name); | Assert.AreEqual("myop", op.name); | ||||