diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index e9c01aa3..ddf7814f 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -116,10 +116,19 @@ namespace Tensorflow
case int intVal:
nparray = intVal;
break;
+ case int[] intVals:
+ nparray = np.array(intVals);
+ break;
+ case int[,] intVals:
+ nparray = np.array(intVals);
+ break;
case long intVal:
nparray = intVal;
break;
- case int[] intVals:
+ case long[] intVals:
+ nparray = np.array(intVals);
+ break;
+ case long[,] intVals:
nparray = np.array(intVals);
break;
case float floatVal:
@@ -128,9 +137,18 @@ namespace Tensorflow
case float[] floatVals:
nparray = floatVals;
break;
+ case float[,] floatVals:
+ nparray = np.array(floatVals);
+ break;
case double doubleVal:
nparray = doubleVal;
break;
+ case double[] doubleVals:
+ nparray = np.array(doubleVals);
+ break;
+ case double[,] doubleVals:
+ nparray = np.array(doubleVals);
+ break;
case string strVal:
nparray = strVal;
break;
@@ -140,8 +158,11 @@ namespace Tensorflow
case byte[] byteValues:
nparray = byteValues;
break;
+ case byte[,] byteValues:
+ nparray = np.array(byteValues);
+ break;
default:
- throw new NotImplementedException("make_tensor_proto Not Implemented");
+ throw new NotImplementedException($"make_tensor_proto: Support for type {values.GetType()} Not Implemented");
}
}
else
@@ -174,7 +195,7 @@ namespace Tensorflow
nparray = Convert.ToString(values);
break;
default:
- throw new NotImplementedException("make_tensor_proto Not Implemented");
+ throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented");
}
}
}
diff --git a/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs
new file mode 100644
index 00000000..ce922193
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs
@@ -0,0 +1,168 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow;
+
+namespace TensorFlowNET.UnitTest
+{
+ ///
+ /// excerpt of tensorflow/python/framework/ops_test.py
+ /// # These cases test the private Graph._create_op_from_tf_operation
+ /// # method. Arguably we should only test the public APIs that depend on this
+ /// # method. However, this logic is complex and tricky, and it can be difficult to
+ /// # ascertain if we have adequate coverage (e.g. a graph may run successfully if
+ /// # the control flow context isn't set properly, but a more complicated use case
+ /// # that might not be obvious to test will fail). Thus we instead explicitly test
+ /// # the low-level behavior.
+ ///
+ [TestClass]
+ public class CreateOpFromTfOperationTest : PythonTest
+ {
+
+ [TestMethod]
+ public void TestShape()
+ {
+ var graph = tf.Graph().as_default();
+ with(graph, g =>
+ {
+ var x = constant_op.constant(new [,] { {1, 2, 3}, {4, 5, 6}});
+ var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]);
+ var op = g._create_op_from_tf_operation(c_op);
+
+ Assert.AreEqual("myop", op.name);
+ Assert.AreEqual("Identity", op.type);
+ Assert.AreEqual(1, len(op.outputs));
+ AssertItemsEqual(new []{2, 3}, op.outputs[0].shape);
+ });
+ }
+
+ /*def testUniqueName(self):
+ g = ops.Graph()
+ with g.as_default():
+ c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
+ c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
+ op = g._create_op_from_tf_operation(c_op)
+ op2 = g._create_op_from_tf_operation(c_op2)
+
+ # Create ops with same names as op1 and op2. We expect the new names to be
+ # uniquified.
+ op3 = test_ops.int_output(name="myop").op
+ op4 = test_ops.int_output(name="myop_1").op
+
+ self.assertEqual(op.name, "myop")
+ self.assertEqual(op2.name, "myop_1")
+ self.assertEqual(op3.name, "myop_2")
+ self.assertEqual(op4.name, "myop_1_1")
+
+ @test_util.run_v1_only("b/120545219")
+ def testCond(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+
+ def true_fn():
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "cond/myop"), [x], [])
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
+ return x
+
+ control_flow_ops.cond(x < 10, true_fn, lambda: x)
+
+ op = g.get_operation_by_name("cond/myop")
+ self.assertIsNotNone(op)
+ self.assertEqual(op.name, "cond/myop")
+ self.assertEqual(op.type, "IntInput")
+ self.assertEqual(op.outputs, [])
+ op_input = op.inputs[0].op
+ self.assertEqual(op_input.type, "Switch")
+ self.assertEqual(op_input.inputs[0], x)
+ self.assertEqual(op.graph, g)
+ # pylint: disable=protected-access
+ self.assertIsNotNone(op._get_control_flow_context())
+ self.assertEqual(op._get_control_flow_context().name,
+ "cond/cond_text")
+ # pylint: enable=protected-access
+
+ @test_util.run_v1_only("b/120545219")
+ def testWhileLoop(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+
+ def body(i):
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "myloop/myop"), [x], [])
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
+ return i
+
+ control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
+
+ op = g.get_operation_by_name("myloop/myop")
+ self.assertIsNotNone(op)
+ self.assertEqual(op.name, "myloop/myop")
+ self.assertEqual(op.type, "IntInput")
+ self.assertEqual(op.outputs, [])
+ op_input = op.inputs[0].op
+ self.assertEqual(op_input.type, "Enter")
+ self.assertEqual(list(op_input.inputs), [x])
+ self.assertEqual(op.graph, g)
+ # pylint: disable=protected-access
+ self.assertIsNotNone(op._get_control_flow_context())
+ self.assertEqual(op._get_control_flow_context().name,
+ "myloop/while_context")
+ # pylint: enable=protected-access
+
+ @test_util.run_v1_only("b/120545219")
+ def testWhileLoopWithInternalControlDep(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+
+ def body(i):
+ c = constant_op.constant(1.0, name="c")
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "myloop/myop"), [x], [])
+ with ops.control_dependencies([c]):
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
+ return i
+
+ control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
+
+ op = g.get_operation_by_name("myloop/myop")
+ self.assertIsNotNone(op)
+ c = g.get_operation_by_name("myloop/c")
+ self.assertIsNotNone(c)
+ # Internal control dep is preserved
+ self.assertEqual(op.control_inputs, [c])
+
+ @test_util.run_v1_only("b/120545219")
+ def testWhileLoopWithExternalControlDep(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+ c = constant_op.constant(1.0)
+
+ def body(i):
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "myloop/myop"), [x], [])
+ with ops.control_dependencies([c]):
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
+ return i
+
+ control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
+
+ op = g.get_operation_by_name("myloop/myop")
+ self.assertIsNotNone(op)
+ # External control dep is removed and replaced with internal control dep
+ self.assertNotEqual(op.control_inputs[0], c.op)
+ self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
+
+
+ */
+ }
+ }