diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 284f9ee6..3cfbe4ea 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -38,6 +38,10 @@ namespace Tensorflow return; _handle = handle; + + _outputs = new Tensor[NumOutputs]; + for (int i = 0; i < NumOutputs; i++) + _outputs[i] = new Tensor(this, i, OutputType(i)); } public Operation(Graph g, string opType, string oper_name) @@ -99,7 +103,6 @@ namespace Tensorflow // Initialize self._outputs. output_types = new TF_DataType[NumOutputs]; - for (int i = 0; i < NumOutputs; i++) output_types[i] = OutputType(i); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs index 5f1860c6..a68f28c1 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs @@ -16,6 +16,11 @@ namespace Tensorflow return constant_op.Constant(scalar); } + public static implicit operator int(Tensor tensor) + { + return tensor.Data()[0]; + } + public static implicit operator IntPtr(Tensor tensor) { return tensor._handle; diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs index 8345d0b7..5e7fe2a5 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.Implicit.cs @@ -23,12 +23,6 @@ namespace Tensorflow public static implicit operator RefVariable(Tensor var) { - switch (var.dtype) - { - case TF_DataType.TF_INT32: - return tf.Variable(var.Data()[0]); - } - return null; } } diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 95313d4c..3158eeaf 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -43,7 +43,8 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void Add() { - var x = tf.Variable(10, name: "x"); + int result = 0; + Tensor x = tf.Variable(10, name: "x"); var model = tf.global_variables_initializer(); using (var session = tf.Session()) @@ -52,10 +53,12 @@ namespace TensorFlowNET.UnitTest for(int i = 0; i < 5; i++) { x = x + 1; - var result = session.run(x); + result = session.run(x); print(result); } } + + Assert.AreEqual(15, result); } } }