From fe4a06fed301d04215ef936e6555396409bf0417 Mon Sep 17 00:00:00 2001 From: Meinrad Recheis Date: Tue, 9 Apr 2019 18:33:26 +0200 Subject: [PATCH] make_tensor_proto: supported additional types int[,] long[] long[,] float[,] double[] double[,] and byte[,] --- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 27 ++- .../CreateOpFromTfOperationTest.cs | 168 ++++++++++++++++++ 2 files changed, 192 insertions(+), 3 deletions(-) create mode 100644 test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index e9c01aa3..ddf7814f 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -116,10 +116,19 @@ namespace Tensorflow case int intVal: nparray = intVal; break; + case int[] intVals: + nparray = np.array(intVals); + break; + case int[,] intVals: + nparray = np.array(intVals); + break; case long intVal: nparray = intVal; break; - case int[] intVals: + case long[] intVals: + nparray = np.array(intVals); + break; + case long[,] intVals: nparray = np.array(intVals); break; case float floatVal: @@ -128,9 +137,18 @@ namespace Tensorflow case float[] floatVals: nparray = floatVals; break; + case float[,] floatVals: + nparray = np.array(floatVals); + break; case double doubleVal: nparray = doubleVal; break; + case double[] doubleVals: + nparray = np.array(doubleVals); + break; + case double[,] doubleVals: + nparray = np.array(doubleVals); + break; case string strVal: nparray = strVal; break; @@ -140,8 +158,11 @@ namespace Tensorflow case byte[] byteValues: nparray = byteValues; break; + case byte[,] byteValues: + nparray = np.array(byteValues); + break; default: - throw new NotImplementedException("make_tensor_proto Not Implemented"); + throw new NotImplementedException($"make_tensor_proto: Support for type {values.GetType()} Not Implemented"); } } else @@ -174,7 +195,7 @@ namespace Tensorflow nparray = Convert.ToString(values); break; default: - throw new NotImplementedException("make_tensor_proto Not Implemented"); + throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); } } } diff --git a/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs new file mode 100644 index 00000000..ce922193 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs @@ -0,0 +1,168 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; + +namespace TensorFlowNET.UnitTest +{ + /// + /// excerpt of tensorflow/python/framework/ops_test.py + /// # These cases test the private Graph._create_op_from_tf_operation + /// # method. Arguably we should only test the public APIs that depend on this + /// # method. However, this logic is complex and tricky, and it can be difficult to + /// # ascertain if we have adequate coverage (e.g. a graph may run successfully if + /// # the control flow context isn't set properly, but a more complicated use case + /// # that might not be obvious to test will fail). Thus we instead explicitly test + /// # the low-level behavior. + /// + [TestClass] + public class CreateOpFromTfOperationTest : PythonTest + { + + [TestMethod] + public void TestShape() + { + var graph = tf.Graph().as_default(); + with(graph, g => + { + var x = constant_op.constant(new [,] { {1, 2, 3}, {4, 5, 6}}); + var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); + var op = g._create_op_from_tf_operation(c_op); + + Assert.AreEqual("myop", op.name); + Assert.AreEqual("Identity", op.type); + Assert.AreEqual(1, len(op.outputs)); + AssertItemsEqual(new []{2, 3}, op.outputs[0].shape); + }); + } + + /*def testUniqueName(self): + g = ops.Graph() + with g.as_default(): + c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], []) + c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], []) + op = g._create_op_from_tf_operation(c_op) + op2 = g._create_op_from_tf_operation(c_op2) + + # Create ops with same names as op1 and op2. We expect the new names to be + # uniquified. + op3 = test_ops.int_output(name="myop").op + op4 = test_ops.int_output(name="myop_1").op + + self.assertEqual(op.name, "myop") + self.assertEqual(op2.name, "myop_1") + self.assertEqual(op3.name, "myop_2") + self.assertEqual(op4.name, "myop_1_1") + + @test_util.run_v1_only("b/120545219") + def testCond(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + + def true_fn(): + ops._create_c_op(ops.get_default_graph(), + ops._NodeDef("IntInput", "cond/myop"), [x], []) + new_ops = g._add_new_tf_operations() + self.assertEqual(len(new_ops), 1) + return x + + control_flow_ops.cond(x < 10, true_fn, lambda: x) + + op = g.get_operation_by_name("cond/myop") + self.assertIsNotNone(op) + self.assertEqual(op.name, "cond/myop") + self.assertEqual(op.type, "IntInput") + self.assertEqual(op.outputs, []) + op_input = op.inputs[0].op + self.assertEqual(op_input.type, "Switch") + self.assertEqual(op_input.inputs[0], x) + self.assertEqual(op.graph, g) + # pylint: disable=protected-access + self.assertIsNotNone(op._get_control_flow_context()) + self.assertEqual(op._get_control_flow_context().name, + "cond/cond_text") + # pylint: enable=protected-access + + @test_util.run_v1_only("b/120545219") + def testWhileLoop(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + + def body(i): + ops._create_c_op(ops.get_default_graph(), + ops._NodeDef("IntInput", "myloop/myop"), [x], []) + new_ops = g._add_new_tf_operations() + self.assertEqual(len(new_ops), 1) + return i + + control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") + + op = g.get_operation_by_name("myloop/myop") + self.assertIsNotNone(op) + self.assertEqual(op.name, "myloop/myop") + self.assertEqual(op.type, "IntInput") + self.assertEqual(op.outputs, []) + op_input = op.inputs[0].op + self.assertEqual(op_input.type, "Enter") + self.assertEqual(list(op_input.inputs), [x]) + self.assertEqual(op.graph, g) + # pylint: disable=protected-access + self.assertIsNotNone(op._get_control_flow_context()) + self.assertEqual(op._get_control_flow_context().name, + "myloop/while_context") + # pylint: enable=protected-access + + @test_util.run_v1_only("b/120545219") + def testWhileLoopWithInternalControlDep(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + + def body(i): + c = constant_op.constant(1.0, name="c") + ops._create_c_op(ops.get_default_graph(), + ops._NodeDef("IntInput", "myloop/myop"), [x], []) + with ops.control_dependencies([c]): + new_ops = g._add_new_tf_operations() + self.assertEqual(len(new_ops), 1) + return i + + control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") + + op = g.get_operation_by_name("myloop/myop") + self.assertIsNotNone(op) + c = g.get_operation_by_name("myloop/c") + self.assertIsNotNone(c) + # Internal control dep is preserved + self.assertEqual(op.control_inputs, [c]) + + @test_util.run_v1_only("b/120545219") + def testWhileLoopWithExternalControlDep(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + c = constant_op.constant(1.0) + + def body(i): + ops._create_c_op(ops.get_default_graph(), + ops._NodeDef("IntInput", "myloop/myop"), [x], []) + with ops.control_dependencies([c]): + new_ops = g._add_new_tf_operations() + self.assertEqual(len(new_ops), 1) + return i + + control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") + + op = g.get_operation_by_name("myloop/myop") + self.assertIsNotNone(op) + # External control dep is removed and replaced with internal control dep + self.assertNotEqual(op.control_inputs[0], c.op) + self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) + + + */ + } + }